Internal C API headers.
Change: 153717143
This commit is contained in:
parent
7933420df6
commit
79d6be3b86
@ -26,6 +26,22 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api.h"],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
srcs = ["c_api.cc"],
|
||||
@ -34,9 +50,11 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/cc/saved_model:loader",
|
||||
"//tensorflow/cc:gradients",
|
||||
"//tensorflow/cc:ops",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#endif
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -96,9 +97,6 @@ size_t TF_DataTypeSize(TF_DataType dt) {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
struct TF_Status {
|
||||
Status status;
|
||||
};
|
||||
|
||||
TF_Status* TF_NewStatus() { return new TF_Status; }
|
||||
|
||||
@ -182,12 +180,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
|
||||
|
||||
} // namespace
|
||||
|
||||
struct TF_Tensor {
|
||||
TF_DataType dtype;
|
||||
TensorShape shape;
|
||||
TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||
int num_dims, size_t len) {
|
||||
void* data = allocate_tensor("TF_AllocateTensor", len);
|
||||
@ -292,9 +284,6 @@ size_t TF_StringEncodedSize(size_t len) {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
struct TF_SessionOptions {
|
||||
SessionOptions options;
|
||||
};
|
||||
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
|
||||
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
|
||||
|
||||
@ -335,9 +324,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) {
|
||||
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
struct TF_DeprecatedSession {
|
||||
Session* session;
|
||||
};
|
||||
|
||||
TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
|
||||
TF_Status* status) {
|
||||
@ -701,11 +687,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
|
||||
c_outputs, target_oper_names, nullptr, status);
|
||||
}
|
||||
|
||||
struct TF_Library {
|
||||
void* lib_handle;
|
||||
TF_Buffer op_list;
|
||||
};
|
||||
|
||||
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
|
||||
TF_Library* lib_handle = new TF_Library;
|
||||
status->status = tensorflow::LoadLibrary(
|
||||
@ -742,66 +723,6 @@ TF_Buffer* TF_GetAllOpList() {
|
||||
// --------------------------------------------------------------------------
|
||||
// New Graph and Session API
|
||||
|
||||
// Structures -----------------------------------------------------------------
|
||||
|
||||
extern "C" {
|
||||
|
||||
struct TF_Graph {
|
||||
TF_Graph()
|
||||
: graph(OpRegistry::Global()),
|
||||
refiner(graph.versions().producer(), graph.op_registry()),
|
||||
num_sessions(0),
|
||||
delete_requested(false),
|
||||
parent(nullptr),
|
||||
parent_inputs(nullptr) {}
|
||||
mutex mu;
|
||||
Graph graph GUARDED_BY(mu);
|
||||
|
||||
// Runs shape inference.
|
||||
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
|
||||
|
||||
// Maps from name of an operation to the Node* in 'graph'.
|
||||
std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);
|
||||
|
||||
// TF_Graph may only / must be deleted when
|
||||
// num_sessions == 0 && delete_requested == true
|
||||
|
||||
// num_sessions incremented by TF_NewSession, and decremented by
|
||||
// TF_DeleteSession.
|
||||
int num_sessions GUARDED_BY(mu);
|
||||
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
|
||||
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
||||
// will eventually contain the full while loop.
|
||||
TF_Graph* parent;
|
||||
TF_Output* parent_inputs;
|
||||
};
|
||||
|
||||
struct TF_OperationDescription {
|
||||
TF_OperationDescription(TF_Graph* g, const char* op_type,
|
||||
const char* node_name)
|
||||
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
|
||||
|
||||
NodeBuilder node_builder;
|
||||
TF_Graph* graph;
|
||||
std::vector<tensorflow::string> colocation_constraints;
|
||||
};
|
||||
|
||||
struct TF_Operation {
|
||||
Node node;
|
||||
};
|
||||
|
||||
struct TF_Session {
|
||||
TF_Session(Session* s, TF_Graph* g)
|
||||
: session(s), graph(g), last_num_graph_nodes(0) {}
|
||||
Session* session;
|
||||
TF_Graph* graph;
|
||||
mutex mu;
|
||||
int last_num_graph_nodes;
|
||||
};
|
||||
|
||||
} // end extern "C"
|
||||
|
||||
// Helper functions -----------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
@ -1691,10 +1612,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
|
||||
status->status = MessageToBuffer(def, output_graph_def);
|
||||
}
|
||||
|
||||
struct TF_ImportGraphDefOptions {
|
||||
tensorflow::ImportGraphDefOptions opts;
|
||||
};
|
||||
|
||||
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
|
||||
return new TF_ImportGraphDefOptions;
|
||||
}
|
||||
|
116
tensorflow/c/c_api_internal.h
Normal file
116
tensorflow/c/c_api_internal.h
Normal file
@ -0,0 +1,116 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
|
||||
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
// not be depended on.
|
||||
|
||||
struct TF_Status {
|
||||
tensorflow::Status status;
|
||||
};
|
||||
|
||||
struct TF_Tensor {
|
||||
TF_DataType dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
tensorflow::TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
struct TF_SessionOptions {
|
||||
tensorflow::SessionOptions options;
|
||||
};
|
||||
|
||||
struct TF_DeprecatedSession {
|
||||
tensorflow::Session* session;
|
||||
};
|
||||
|
||||
struct TF_Library {
|
||||
void* lib_handle;
|
||||
TF_Buffer op_list;
|
||||
};
|
||||
|
||||
struct TF_Graph {
|
||||
TF_Graph()
|
||||
: graph(tensorflow::OpRegistry::Global()),
|
||||
refiner(graph.versions().producer(), graph.op_registry()),
|
||||
num_sessions(0),
|
||||
delete_requested(false),
|
||||
parent(nullptr),
|
||||
parent_inputs(nullptr) {}
|
||||
tensorflow::mutex mu;
|
||||
tensorflow::Graph graph GUARDED_BY(mu);
|
||||
|
||||
// Runs shape inference.
|
||||
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
|
||||
|
||||
// Maps from name of an operation to the Node* in 'graph'.
|
||||
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
|
||||
GUARDED_BY(mu);
|
||||
|
||||
// TF_Graph may only / must be deleted when
|
||||
// num_sessions == 0 && delete_requested == true
|
||||
|
||||
// num_sessions incremented by TF_NewSession, and decremented by
|
||||
// TF_DeleteSession.
|
||||
int num_sessions GUARDED_BY(mu);
|
||||
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
|
||||
|
||||
// Used to link graphs contained in TF_WhileParams to the parent graph that
|
||||
// will eventually contain the full while loop.
|
||||
TF_Graph* parent;
|
||||
TF_Output* parent_inputs;
|
||||
};
|
||||
|
||||
struct TF_OperationDescription {
|
||||
TF_OperationDescription(TF_Graph* g, const char* op_type,
|
||||
const char* node_name)
|
||||
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
|
||||
|
||||
tensorflow::NodeBuilder node_builder;
|
||||
TF_Graph* graph;
|
||||
std::vector<tensorflow::string> colocation_constraints;
|
||||
};
|
||||
|
||||
struct TF_Operation {
|
||||
tensorflow::Node node;
|
||||
};
|
||||
|
||||
struct TF_Session {
|
||||
TF_Session(tensorflow::Session* s, TF_Graph* g)
|
||||
: session(s), graph(g), last_num_graph_nodes(0) {}
|
||||
tensorflow::Session* session;
|
||||
TF_Graph* graph;
|
||||
tensorflow::mutex mu;
|
||||
int last_num_graph_nodes;
|
||||
};
|
||||
|
||||
struct TF_ImportGraphDefOptions {
|
||||
tensorflow::ImportGraphDefOptions opts;
|
||||
};
|
Loading…
Reference in New Issue
Block a user