From 79d6be3b86ba32048cac6afd5d1c5d1bb8aee6d9 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 20 Apr 2017 08:18:28 -0800 Subject: [PATCH] Internal C API headers. Change: 153717143 --- tensorflow/c/BUILD | 18 ++++++ tensorflow/c/c_api.cc | 85 +------------------------ tensorflow/c/c_api_internal.h | 116 ++++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 84 deletions(-) create mode 100644 tensorflow/c/c_api_internal.h diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index af96ce70b69..4ad69ae3fbd 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -26,6 +26,22 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +tf_cuda_library( + name = "c_api_internal", + srcs = ["c_api.h"], + hdrs = ["c_api_internal.h"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + }), +) + tf_cuda_library( name = "c_api", srcs = ["c_api.cc"], @@ -34,9 +50,11 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ + ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":c_api_internal", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 7ef7e3c5404..0f66a47b4ad 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/saved_model/loader.h" #endif +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" @@ -96,9 +97,6 @@ size_t TF_DataTypeSize(TF_DataType dt) { } // -------------------------------------------------------------------------- -struct TF_Status { - Status status; -}; TF_Status* TF_NewStatus() { return new TF_Status; } @@ -182,12 +180,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } // namespace -struct TF_Tensor { - TF_DataType dtype; - TensorShape shape; - TensorBuffer* buffer; -}; - TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, int num_dims, size_t len) { void* data = allocate_tensor("TF_AllocateTensor", len); @@ -292,9 +284,6 @@ size_t TF_StringEncodedSize(size_t len) { } // -------------------------------------------------------------------------- -struct TF_SessionOptions { - SessionOptions options; -}; TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } @@ -335,9 +324,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) { TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } // -------------------------------------------------------------------------- -struct TF_DeprecatedSession { - Session* session; -}; TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { @@ -701,11 +687,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, c_outputs, target_oper_names, nullptr, status); } -struct TF_Library { - void* lib_handle; - TF_Buffer op_list; -}; - TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Library* lib_handle = new TF_Library; status->status = tensorflow::LoadLibrary( @@ -742,66 +723,6 @@ TF_Buffer* TF_GetAllOpList() { // -------------------------------------------------------------------------- // New Graph and Session API -// Structures ----------------------------------------------------------------- - -extern "C" { - -struct TF_Graph { - TF_Graph() - : graph(OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) {} - mutex mu; - Graph graph GUARDED_BY(mu); - - // Runs shape inference. - tensorflow::ShapeRefiner refiner GUARDED_BY(mu); - - // Maps from name of an operation to the Node* in 'graph'. - std::unordered_map name_map GUARDED_BY(mu); - - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by - // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); - bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph - - // Used to link graphs contained in TF_WhileParams to the parent graph that - // will eventually contain the full while loop. - TF_Graph* parent; - TF_Output* parent_inputs; -}; - -struct TF_OperationDescription { - TF_OperationDescription(TF_Graph* g, const char* op_type, - const char* node_name) - : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} - - NodeBuilder node_builder; - TF_Graph* graph; - std::vector colocation_constraints; -}; - -struct TF_Operation { - Node node; -}; - -struct TF_Session { - TF_Session(Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0) {} - Session* session; - TF_Graph* graph; - mutex mu; - int last_num_graph_nodes; -}; - -} // end extern "C" - // Helper functions ----------------------------------------------------------- namespace { @@ -1691,10 +1612,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, status->status = MessageToBuffer(def, output_graph_def); } -struct TF_ImportGraphDefOptions { - tensorflow::ImportGraphDefOptions opts; -}; - TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h new file mode 100644 index 00000000000..b5320d20dad --- /dev/null +++ b/tensorflow/c/c_api_internal.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" + + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TF_Status { + tensorflow::Status status; +}; + +struct TF_Tensor { + TF_DataType dtype; + tensorflow::TensorShape shape; + tensorflow::TensorBuffer* buffer; +}; + +struct TF_SessionOptions { + tensorflow::SessionOptions options; +}; + +struct TF_DeprecatedSession { + tensorflow::Session* session; +}; + +struct TF_Library { + void* lib_handle; + TF_Buffer op_list; +}; + +struct TF_Graph { + TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + tensorflow::mutex mu; + tensorflow::Graph graph GUARDED_BY(mu); + + // Runs shape inference. + tensorflow::ShapeRefiner refiner GUARDED_BY(mu); + + // Maps from name of an operation to the Node* in 'graph'. + std::unordered_map name_map + GUARDED_BY(mu); + + // TF_Graph may only / must be deleted when + // num_sessions == 0 && delete_requested == true + + // num_sessions incremented by TF_NewSession, and decremented by + // TF_DeleteSession. + int num_sessions GUARDED_BY(mu); + bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph + + // Used to link graphs contained in TF_WhileParams to the parent graph that + // will eventually contain the full while loop. + TF_Graph* parent; + TF_Output* parent_inputs; +}; + +struct TF_OperationDescription { + TF_OperationDescription(TF_Graph* g, const char* op_type, + const char* node_name) + : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} + + tensorflow::NodeBuilder node_builder; + TF_Graph* graph; + std::vector colocation_constraints; +}; + +struct TF_Operation { + tensorflow::Node node; +}; + +struct TF_Session { + TF_Session(tensorflow::Session* s, TF_Graph* g) + : session(s), graph(g), last_num_graph_nodes(0) {} + tensorflow::Session* session; + TF_Graph* graph; + tensorflow::mutex mu; + int last_num_graph_nodes; +}; + +struct TF_ImportGraphDefOptions { + tensorflow::ImportGraphDefOptions opts; +};