2569 lines
93 KiB
C++
2569 lines
93 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/c_api.h"
|
|
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/match.h"
|
|
// Required for IS_MOBILE_PLATFORM
|
|
#include "tensorflow/core/platform/platform.h" // NOLINT
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
#include "tensorflow/cc/framework/gradients.h"
|
|
#include "tensorflow/cc/framework/ops.h"
|
|
#include "tensorflow/cc/framework/scope_internal.h"
|
|
#include "tensorflow/cc/ops/while_loop.h"
|
|
#include "tensorflow/cc/saved_model/loader.h"
|
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
|
#include "tensorflow/core/framework/logging.h"
|
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/c/tf_status_internal.h"
|
|
#include "tensorflow/c/tf_tensor.h"
|
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
|
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
|
|
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
|
#include "tensorflow/core/framework/kernel_def.pb.h"
|
|
#include "tensorflow/core/framework/log_memory.h"
|
|
#include "tensorflow/core/framework/node_def_util.h"
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
|
#include "tensorflow/core/framework/tensor.h"
|
|
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
|
#include "tensorflow/core/framework/types.h"
|
|
#include "tensorflow/core/framework/versions.pb.h"
|
|
#include "tensorflow/core/graph/graph.h"
|
|
#include "tensorflow/core/graph/node_builder.h"
|
|
#include "tensorflow/core/graph/validate.h"
|
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
|
#include "tensorflow/core/platform/coding.h"
|
|
#include "tensorflow/core/platform/errors.h"
|
|
#include "tensorflow/core/platform/mem.h"
|
|
#include "tensorflow/core/platform/mutex.h"
|
|
#include "tensorflow/core/platform/protobuf.h"
|
|
#include "tensorflow/core/platform/status.h"
|
|
#include "tensorflow/core/platform/str_util.h"
|
|
#include "tensorflow/core/platform/strcat.h"
|
|
#include "tensorflow/core/platform/stringpiece.h"
|
|
#include "tensorflow/core/platform/thread_annotations.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
#include "tensorflow/core/public/session.h"
|
|
#include "tensorflow/core/public/version.h"
|
|
|
|
// The implementation below is at the top level instead of inside the
// tensorflow namespace because it defines 'extern "C"' functions.
|
|
using tensorflow::AllocationDescription;
|
|
using tensorflow::DataType;
|
|
using tensorflow::ExtendSessionGraphHelper;
|
|
using tensorflow::Graph;
|
|
using tensorflow::GraphDef;
|
|
using tensorflow::mutex_lock;
|
|
using tensorflow::NameRangeMap;
|
|
using tensorflow::NameRangesForNode;
|
|
using tensorflow::NewSession;
|
|
using tensorflow::Node;
|
|
using tensorflow::NodeBuilder;
|
|
using tensorflow::NodeDef;
|
|
using tensorflow::OpDef;
|
|
using tensorflow::OpRegistry;
|
|
using tensorflow::OutputTensor;
|
|
using tensorflow::PartialTensorShape;
|
|
using tensorflow::RunMetadata;
|
|
using tensorflow::RunOptions;
|
|
using tensorflow::Session;
|
|
using tensorflow::Status;
|
|
using tensorflow::string;
|
|
using tensorflow::Tensor;
|
|
using tensorflow::TensorBuffer;
|
|
using tensorflow::TensorId;
|
|
using tensorflow::TensorShape;
|
|
using tensorflow::TensorShapeProto;
|
|
using tensorflow::VersionDef;
|
|
using tensorflow::errors::FailedPrecondition;
|
|
using tensorflow::errors::InvalidArgument;
|
|
using tensorflow::gtl::ArraySlice;
|
|
using tensorflow::strings::StrCat;
|
|
|
|
extern "C" {
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Returns the build-time TensorFlow version string (TF_VERSION_STRING).
const char* TF_Version() {
  const char* const version = TF_VERSION_STRING;
  return version;
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Heap-allocates a TF_SessionOptions; release it with TF_DeleteSessionOptions.
TF_SessionOptions* TF_NewSessionOptions() {
  TF_SessionOptions* const opts = new TF_SessionOptions;
  return opts;
}
|
|
// Frees options created by TF_NewSessionOptions. A null `opt` is a no-op
// because deleting a null pointer does nothing.
void TF_DeleteSessionOptions(TF_SessionOptions* opt) {
  delete opt;
}
|
|
|
|
// Copies `target` into the `target` field of the wrapped session options.
void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  auto& session_options = options->options;
  session_options.target = target;
}
|
|
|
|
// Parses the serialized ConfigProto in [proto, proto + proto_len) into
// `options`. On failure `status` is set to InvalidArgument. On success
// `status` is left untouched.
void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  // ParseFromArray() takes an int length. A larger size_t would be narrowed,
  // which could make it parse only a prefix of the buffer, so reject it.
  if (proto_len > static_cast<size_t>(std::numeric_limits<int>::max())) {
    status->status = InvalidArgument("ConfigProto is too large (", proto_len,
                                     " bytes)");
    return;
  }
  if (!options->options.config.ParseFromArray(proto,
                                              static_cast<int>(proto_len))) {
    status->status = InvalidArgument("Unparseable ConfigProto");
  }
}
|
|
// --------------------------------------------------------------------------
|
|
// Allocates an empty TF_Buffer: null data, zero length, no deallocator.
TF_Buffer* TF_NewBuffer() {
  TF_Buffer* const buffer = new TF_Buffer{nullptr, 0, nullptr};
  return buffer;
}
|
|
|
|
// Returns a new TF_Buffer holding a private copy of the `proto_len` bytes at
// `proto`. The copy is allocated with tensorflow::port::Malloc. The buffer's
// deallocator releases it with tensorflow::port::Free, which TF_DeleteBuffer
// invokes.
TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = tensorflow::port::Malloc(proto_len);
  // memcpy with a null source or destination is undefined even for zero
  // bytes, and Malloc(0) and `proto` may both be null for an empty buffer.
  if (proto_len > 0) {
    memcpy(copy, proto, proto_len);
  }

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return buf;
}
|
|
|
|
// Destroys `buffer`. If it has a data_deallocator, that function is first
// called with the buffer's data and length. A null `buffer` is a no-op.
void TF_DeleteBuffer(TF_Buffer* buffer) {
  if (buffer == nullptr) {
    return;
  }
  auto* const release = buffer->data_deallocator;
  if (release != nullptr) {
    release(const_cast<void*>(buffer->data), buffer->length);
  }
  delete buffer;
}
|
|
|
|
// Returns a shallow copy of `*buffer`. The data pointer and deallocator are
// copied, not the bytes they refer to.
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) {
  TF_Buffer copy = *buffer;
  return copy;
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
|
|
// Creates a session from `opt` and wraps it in a TF_DeprecatedSession.
// Returns nullptr and reports the error in `status` if creation fails.
TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                              TF_Status* status) {
  // Initialize so the DCHECK below never reads an indeterminate pointer when
  // NewSession() fails without writing its out-parameter.
  Session* session = nullptr;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  }
  DCHECK_EQ(nullptr, session);
  return nullptr;
}
|
|
|
|
// Closes the wrapped session and stores the result of Close() in `status`.
void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  Session* const session = s->session;
  status->status = session->Close();
}
|
|
|
|
// Deletes the wrapped Session and then `s` itself. `status` is always set to
// OK. A null `s` is a no-op.
void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
  if (s != nullptr) {
    delete s->session;
    delete s;
  }
}
|
|
|
|
// Parses a serialized GraphDef from [proto, proto + proto_len) and passes it
// to Session::Extend(). An unparseable GraphDef yields InvalidArgument;
// otherwise `status` receives Extend()'s result.
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef graph_def;
  const bool parsed =
      tensorflow::ParseProtoUnlimited(&graph_def, proto, proto_len);
  status->status = parsed ? s->session->Extend(graph_def)
                          : InvalidArgument("Invalid GraphDef");
}
|
|
|
|
} // end extern "C"
|
|
|
|
// Reset helper for converting character arrays to string vectors.
|
|
static void TF_Reset_Helper(const TF_SessionOptions* opt,
|
|
const char** containers, int ncontainers,
|
|
TF_Status* status) {
|
|
std::vector<string> container_names(ncontainers);
|
|
for (int i = 0; i < ncontainers; ++i) {
|
|
container_names[i] = containers[i];
|
|
}
|
|
|
|
status->status = Reset(opt->options, container_names);
|
|
}
|
|
|
|
extern "C" {
|
|
|
|
void TF_Reset(const TF_SessionOptions* opt, const char** containers,
|
|
int ncontainers, TF_Status* status) {
|
|
TF_Reset_Helper(opt, containers, ncontainers, status);
|
|
}
|
|
|
|
} // end extern "C"
|
|
|
|
namespace tensorflow {
|
|
|
|
|
|
// Serializes `in` into a port::Malloc'ed region and stores it in `out`,
// installing a deallocator that calls port::Free.
//
// Returns InvalidArgument if `out` already holds data or serialization fails,
// and ResourceExhausted if the allocation fails. For an empty message, `out`
// may end up with a null data pointer and zero length.
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                       TF_Buffer* out) {
  if (out->data != nullptr) {
    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
  }
  const size_t proto_size = in.ByteSizeLong();
  void* buf = port::Malloc(proto_size);
  // A zero-byte request may legitimately return nullptr (as malloc(0) is
  // allowed to). Only treat a null result as an allocation failure when bytes
  // were actually requested.
  if (buf == nullptr && proto_size > 0) {
    return tensorflow::errors::ResourceExhausted(
        "Failed to allocate memory to serialize message of type '",
        in.GetTypeName(), "' and size ", proto_size);
  }
  // Skip serializing an empty message. Writing zero bytes into a possibly-null
  // buffer returns nullptr, which would be mistaken for a failure.
  if (proto_size > 0 &&
      !in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
    port::Free(buf);
    return InvalidArgument("Unable to serialize ", in.GetTypeName(),
                           " protocol buffer, perhaps the serialized size (",
                           proto_size, " bytes) is too large?");
  }
  out->data = buf;
  out->length = proto_size;
  out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
  return Status::OK();
}
|
|
|
|
// Records that `op` was changed (`mutation_type` describes how) after a
// session may already have run it. For each session registered in
// `graph->sessions` whose last_num_graph_nodes exceeds op's node id, stores a
// warning string in that session's map entry. ExtendSessionGraphHelper later
// logs and clears the string.
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
                    const char* mutation_type) {
  // Iterate by reference: iterating by value would copy each map entry and
  // write the warning into a temporary that is immediately discarded.
  for (auto& it : graph->sessions) {
    mutex_lock session_lock(it.first->mu);
    // The session has already been extended with this node, so store a
    // warning for that session.
    if (it.first->last_num_graph_nodes > op.node.id()) {
      it.second = strings::StrCat(
          "Operation '", op.node.DebugString(), "' was changed by ",
          mutation_type,
          " after it was run by a session. This mutation will have no effect, "
          "and will trigger an error in the future. Either don't modify "
          "nodes after running them or create a new session.");
    }
  }
}
|
|
|
|
namespace {
|
|
|
|
// Helper method that creates a shape handle for a shape described by dims.
|
|
tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
|
|
tensorflow::shape_inference::InferenceContext* ic, int num_dims,
|
|
const int64_t* dims) {
|
|
if (num_dims != -1) {
|
|
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
|
|
dim_vec.reserve(num_dims);
|
|
for (int i = 0; i < num_dims; ++i) {
|
|
dim_vec.push_back(ic->MakeDim(dims[i]));
|
|
}
|
|
return ic->MakeShape(dim_vec);
|
|
} else {
|
|
return ic->UnknownShape();
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Sets the handle shapes and types of output `output.index` of `output.oper`
// to the `num_shapes_and_types` entries described by (shapes[i], ranks[i],
// types[i]). A rank of -1 means unknown rank.
//
// Sets InvalidArgument in `status` if the node has no inference context or the
// output index is out of range. `status` is left untouched on success.
void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  // Validate the index before handing it to the InferenceContext.
  // NOTE(review): set_output_handle_shapes_and_types() does not appear to
  // bounds-check it.
  if (output.index < 0 || output.index >= ic->num_outputs()) {
    status->status = InvalidArgument(
        "Output index ", output.index, " is out of range for node ",
        node->name(), " which has ", ic->num_outputs(), " outputs");
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}
|
|
|
|
// Helpers for loading a TensorFlow plugin (a .so file).
|
|
Status LoadLibrary(const char* library_filename, void** result,
|
|
const void** buf, size_t* len);
|
|
|
|
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
|
// directly, instead of requiring us to serialize to a GraphDef and
|
|
// call Session::Extend().
|
|
// Sends the nodes added to session->graph since the previous call (ids in
// [session->last_num_graph_nodes, num_node_ids())) to session->session via
// Session::Extend(). Before doing so, it logs and clears any pending mutation
// warning for this session and validates that the graph has no cycles.
//
// Returns true on success or when there is nothing to do, including when
// session->graph is null. Returns false with the error in `status` otherwise.
// last_num_graph_nodes only advances after Extend() succeeds.
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  TF_Graph* const tf_graph = session->graph;
  if (tf_graph == nullptr) {
    return true;
  }

  // Take the graph lock before the session lock to avoid deadlock. This is
  // safe since session->graph does not change. The graph lock is released
  // manually on every path below (before Extend() is called). The session lock
  // is held until the function returns.
  tf_graph->mu.lock();
  mutex_lock session_lock(session->mu);
  const Graph& full_graph = tf_graph->graph;

  string& mutation_warning = tf_graph->sessions[session];
  if (!mutation_warning.empty()) {
    // TODO(b/74949947): turn this back into an error status
    LOG(WARNING) << mutation_warning;
    mutation_warning.clear();
  }

  const auto num_nodes = full_graph.num_node_ids();
  if (session->last_num_graph_nodes >= num_nodes) {
    // No nodes were added since the last extension.
    tf_graph->mu.unlock();
    return true;
  }

  // TODO(nolivia): check this on a subset of the graph instead of all of it.
  status->status = graph::ValidateGraphHasNoCycle(full_graph);
  if (!status->status.ok()) {
    tf_graph->mu.unlock();
    return false;
  }

  // Collect only the op nodes added since the last TF_SessionRun() call.
  GraphDef graph_def;
  *graph_def.mutable_versions() = full_graph.versions();
  for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
    Node* const node = full_graph.FindNodeId(id);
    if (node != nullptr && node->IsOp()) {
      *graph_def.add_node() = node->def();
    }
  }
  *graph_def.mutable_library() = full_graph.flib_def().ToProto();
  tf_graph->mu.unlock();

  status->status = session->session->Extend(std::move(graph_def));
  if (!status->status.ok()) {
    return false;
  }
  // session->session is not modified if Extend() fails, so the node count is
  // only advanced on success.
  session->last_num_graph_nodes = num_nodes;
  return true;
}
|
|
|
|
} // namespace tensorflow
|
|
|
|
static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
|
|
TF_Status* status) {
|
|
status->status = Status::OK();
|
|
for (int i = 0; i < noutputs; ++i) {
|
|
c_outputs[i] = nullptr;
|
|
}
|
|
}
|
|
|
|
// Converts c_inputs[i] into the Tensor half of (*input_pairs)[i] for every
// entry already present in `input_pairs`. Stops at the first conversion
// failure, leaving the error in `status`. Returns true only if every
// conversion succeeded.
static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
                          std::vector<std::pair<string, Tensor>>* input_pairs,
                          TF_Status* status) {
  auto& pairs = *input_pairs;
  const int count = pairs.size();
  for (int idx = 0; idx < count; ++idx) {
    Tensor* const dst = &pairs[idx].second;
    status->status = TF_TensorToTensor(c_inputs[idx], dst);
    if (!status->status.ok()) {
      return false;
    }
  }
  return true;
}
|
|
|
|
// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
|
|
// result in a zero-sized tensor.
|
|
// Builds a TF_Tensor of type `dtype` with the dimensions of `shape`. CHECK-
// fails unless the shape has zero elements. The tensor points at a static
// dummy byte with length 0, and its deallocator does nothing.
static TF_Tensor* EmptyTensor(TF_DataType dtype,
                              const tensorflow::TensorShape& shape) {
  // Gives the tensor a non-null data pointer. It is never read or written,
  // since the tensor has no elements and a length of 0.
  static char empty;

  const int rank = shape.dims();
  std::vector<tensorflow::int64> dims(rank);
  tensorflow::int64 nelems = 1;
  for (int d = 0; d < rank; ++d) {
    dims[d] = shape.dim_size(d);
    nelems *= dims[d];
  }
  CHECK_EQ(nelems, 0);

  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  auto no_op_deallocator = [](void*, size_t, void*) {};
  return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
                      rank, reinterpret_cast<void*>(&empty), 0,
                      no_op_deallocator, nullptr);
}
|
|
|
|
// Executes one step on `session` and stores the fetched tensors in
// `c_outputs`.
//
// When `handle` is null this is a full Run():
//   * `run_options`, if non-null, must hold a serialized RunOptions proto.
//   * `run_metadata`, if non-null, must be empty on entry and receives the
//     serialized RunMetadata.
// Otherwise it continues the partial run identified by `handle` via PRun(),
// and `run_options` / `run_metadata` are ignored.
//
// Uninitialized or zero-element results become zero-sized tensors via
// EmptyTensor(). Errors are reported through `status`.
static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    if (run_options != nullptr) {
      // ParseFromArray() takes an int length. Reject buffers whose size would
      // be silently narrowed, which could make it parse only a prefix.
      if (run_options->length >
          static_cast<size_t>(std::numeric_limits<int>::max())) {
        status->status = InvalidArgument("RunOptions proto is too large (",
                                         run_options->length, " bytes)");
        return;
      }
      if (!run_options_proto.ParseFromArray(
              run_options->data, static_cast<int>(run_options->length))) {
        status->status = InvalidArgument("Unparseable RunOptions proto");
        return;
      }
    }
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid.");
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
    if (!status->status.ok()) return;
  }
}
|
|
|
|
extern "C" {
|
|
|
|
// Runs one step on a deprecated session. It feeds `ninputs` named tensors,
// fetches `noutputs` named tensors into `c_outputs`, and runs `ntargets`
// target ops. `run_options` and `run_metadata` are optional serialized protos.
// All c_outputs slots are nulled first. Errors are reported in `status`.
void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
            // Input tensors
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors
            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
            // Target nodes
            const char** c_target_oper_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);

  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) {
    return;
  }
  for (int k = 0; k < ninputs; ++k) {
    input_pairs[k].first = c_input_names[k];
  }

  std::vector<string> output_names;
  output_names.reserve(noutputs);
  for (int k = 0; k < noutputs; ++k) {
    output_names.emplace_back(c_output_names[k]);
  }

  std::vector<string> target_oper_names;
  target_oper_names.reserve(ntargets);
  for (int k = 0; k < ntargets; ++k) {
    target_oper_names.emplace_back(c_target_oper_names[k]);
  }

  TF_Run_Helper(s->session, /*handle=*/nullptr, run_options, input_pairs,
                output_names, c_outputs, target_oper_names, run_metadata,
                status);
}
|
|
|
|
// Prepares a partial run with the given feed, fetch, and target names. On
// success `*handle` receives a NUL-terminated copy of the new run handle,
// allocated with new[]. On failure `*handle` stays nullptr and `status` holds
// the error.
void TF_PRunSetup(TF_DeprecatedSession* s,
                  // Input names
                  const char** c_input_names, int ninputs,
                  // Output names
                  const char** c_output_names, int noutputs,
                  // Target nodes
                  const char** c_target_oper_names, int ntargets,
                  const char** handle, TF_Status* status) {
  *handle = nullptr;

  std::vector<string> input_names;
  input_names.reserve(ninputs);
  std::vector<string> output_names;
  output_names.reserve(noutputs);
  std::vector<string> target_oper_names;
  target_oper_names.reserve(ntargets);

  for (int k = 0; k < ninputs; ++k) {
    input_names.emplace_back(c_input_names[k]);
  }
  for (int k = 0; k < noutputs; ++k) {
    output_names.emplace_back(c_output_names[k]);
  }
  for (int k = 0; k < ntargets; ++k) {
    target_oper_names.emplace_back(c_target_oper_names[k]);
  }

  string new_handle;
  status->status = s->session->PRunSetup(input_names, output_names,
                                         target_oper_names, &new_handle);
  if (!status->status.ok()) {
    return;
  }
  // Copy including the trailing NUL.
  const size_t bytes = new_handle.size() + 1;
  char* const copy = new char[bytes];
  std::copy_n(new_handle.c_str(), bytes, copy);
  *handle = copy;
}
|
|
|
|
// Continues the partial run identified by `handle`. It feeds `ninputs` named
// tensors, fetches `noutputs` named tensors into `c_outputs`, and runs
// `ntargets` target ops. All c_outputs slots are nulled first. Errors are
// reported in `status`.
void TF_PRun(TF_DeprecatedSession* s, const char* handle,
             // Input tensors
             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
             // Output tensors
             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
             // Target nodes
             const char** c_target_oper_names, int ntargets,
             TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);

  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) {
    return;
  }
  for (int k = 0; k < ninputs; ++k) {
    input_pairs[k].first = c_input_names[k];
  }

  std::vector<string> output_names;
  output_names.reserve(noutputs);
  for (int k = 0; k < noutputs; ++k) {
    output_names.emplace_back(c_output_names[k]);
  }

  std::vector<string> target_oper_names;
  target_oper_names.reserve(ntargets);
  for (int k = 0; k < ntargets; ++k) {
    target_oper_names.emplace_back(c_target_oper_names[k]);
  }

  TF_Run_Helper(s->session, handle, /*run_options=*/nullptr, input_pairs,
                output_names, c_outputs, target_oper_names,
                /*run_metadata=*/nullptr, status);
}
|
|
|
|
// Loads the plugin at `library_filename` via tensorflow::LoadLibrary. The
// returned TF_Library holds the library handle and the op-list buffer that
// LoadLibrary fills in. Returns nullptr (with `status` set) on failure.
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  std::unique_ptr<TF_Library> lib(new TF_Library);
  status->status = tensorflow::LoadLibrary(
      library_filename, &lib->lib_handle, &lib->op_list.data,
      &lib->op_list.length);
  if (!status->status.ok()) {
    return nullptr;  // `lib` is freed by unique_ptr.
  }
  return lib.release();
}
|
|
|
|
// Returns, by value, the op-list buffer recorded when the library was loaded.
// The returned struct shares its data pointer with `lib_handle`.
TF_Buffer TF_GetOpList(TF_Library* lib_handle) {
  const TF_Buffer& ops = lib_handle->op_list;
  return ops;
}
|
|
|
|
// Frees the op-list data with port::Free, then deletes `lib_handle`. A null
// handle is a no-op.
void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle != nullptr) {
    void* const op_list_data = const_cast<void*>(lib_handle->op_list.data);
    tensorflow::port::Free(op_list_data);
    delete lib_handle;
  }
}
|
|
|
|
// Returns a new TF_Buffer containing a serialized OpList of every op in the
// global OpRegistry. CHECK-fails if serialization fails.
TF_Buffer* TF_GetAllOpList() {
  std::vector<tensorflow::OpDef> registered;
  tensorflow::OpRegistry::Global()->GetRegisteredOps(&registered);

  tensorflow::OpList op_list;
  for (const tensorflow::OpDef& def : registered) {
    *op_list.add_op() = def;
  }

  TF_Buffer* const out = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(op_list, out));
  return out;
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// ListDevices & SessionListDevices API
|
|
|
|
// Frees a device list returned by one of the *ListDevices functions. A null
// `list` is a no-op.
void TF_DeleteDeviceList(TF_DeviceList* list) {
  delete list;
}
|
|
|
|
// Returns a new TF_DeviceList filled by session->session->ListDevices(). If
// `session` or its wrapped session is null, an empty list is returned and
// `status` is left untouched.
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
  auto* const devices = new TF_DeviceList;
  const bool has_session = session != nullptr && session->session != nullptr;
  if (has_session) {
    status->status = session->session->ListDevices(&devices->response);
  }
  return devices;
}
|
|
|
|
// Deprecated-session variant of TF_SessionListDevices. It returns a new
// TF_DeviceList, which stays empty (with `status` untouched) if `session` or
// its wrapped session is null.
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                               TF_Status* status) {
  auto* const devices = new TF_DeviceList;
  const bool has_session = session != nullptr && session->session != nullptr;
  if (has_session) {
    status->status = session->session->ListDevices(&devices->response);
  }
  return devices;
}
|
|
|
|
int TF_DeviceListCount(const TF_DeviceList* list) {
|
|
return list->response.size();
|
|
}
|
|
|
|
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
|
|
return_type method_name(const TF_DeviceList* list, const int index, \
|
|
TF_Status* status) { \
|
|
if (list == nullptr) { \
|
|
status->status = InvalidArgument("list is null!"); \
|
|
return err_val; \
|
|
} \
|
|
if (index < 0 || index >= list->response.size()) { \
|
|
status->status = InvalidArgument("index out of bounds"); \
|
|
return err_val; \
|
|
} \
|
|
status->status = Status::OK(); \
|
|
return list->response[index].accessor; \
|
|
}
|
|
|
|
TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
|
|
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
|
|
nullptr);
|
|
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
|
|
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
|
|
|
|
#undef TF_DEVICELIST_METHOD
|
|
|
|
} // end extern "C"
|
|
|
|
// --------------------------------------------------------------------------
|
|
// New Graph and Session API
|
|
|
|
// Helper functions -----------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
TF_Operation* ToOperation(Node* node) {
|
|
return static_cast<TF_Operation*>(static_cast<void*>(node));
|
|
}
|
|
|
|
string OutputName(const TF_Output& output) {
|
|
return StrCat(output.oper->node.name(), ":", output.index);
|
|
}
|
|
|
|
const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
|
|
const char* attr_name,
|
|
TF_Status* status) {
|
|
const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
|
|
if (attr == nullptr) {
|
|
status->status = InvalidArgument("Operation '", oper->node.name(),
|
|
"' has no attr named '", attr_name, "'.");
|
|
}
|
|
return attr;
|
|
}
|
|
|
|
TensorId ToTensorId(const TF_Output& output) {
|
|
return TensorId(output.oper->node.name(), output.index);
|
|
}
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
|
|
int n) {
|
|
std::vector<tensorflow::Output> outputs(n);
|
|
for (int i = 0; i < n; ++i) {
|
|
outputs[i] =
|
|
tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
|
|
}
|
|
return outputs;
|
|
}
|
|
|
|
void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
|
|
TF_Output* tf_outputs) {
|
|
for (int i = 0; i < outputs.size(); i++) {
|
|
tf_outputs[i].oper = ToOperation(outputs[i].node());
|
|
tf_outputs[i].index = outputs[i].index();
|
|
}
|
|
}
|
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
|
|
} // namespace
|
|
|
|
// Shape functions -----------------------------------------------------------
|
|
|
|
// Sets the shape of `output` in the graph's shape refiner to the shape
// described by (dims, num_dims). A num_dims of -1 means unknown rank. `status`
// receives InvalidArgument if the node has no inference context; otherwise it
// receives the result of ShapeRefiner::SetShape().
void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
                            const int64_t* dims, const int num_dims,
                            TF_Status* status) {
  Node* const node = &output.oper->node;

  mutex_lock graph_lock(graph->mu);
  auto* const ic = graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  status->status = graph->refiner.SetShape(
      node, output.index, tensorflow::ShapeHandleFromDims(ic, num_dims, dims));
}
|
|
|
|
// Returns the rank of `output`, or -1 if the rank is unknown.
// Also returns -1, with InvalidArgument in `status`, if the node has no
// inference context or `output.index` is out of range.
int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
                             TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return -1;
  }
  // Reject indices that ic->output() cannot serve, instead of reading past
  // the end of its output list.
  if (output.index < 0 || output.index >= ic->num_outputs()) {
    status->status = InvalidArgument(
        "Output index ", output.index, " is out of range for node ",
        node->name(), " which has ", ic->num_outputs(), " outputs");
    return -1;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  // Unknown rank means the number of dimensions is -1.
  if (!ic->RankKnown(shape)) {
    return -1;
  }

  return ic->Rank(shape);
}
|
|
|
|
// Writes the dimensions of `output` into `dims[0..num_dims)`, using -1 for
// unknown dimensions.
//
// `num_dims` must equal the output's rank (-1 for unknown rank); otherwise
// `status` is set to InvalidArgument. InvalidArgument is also set if the node
// has no inference context or `output.index` is out of range.
void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
                            int num_dims, TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  // Reject indices that ic->output() cannot serve, instead of reading past
  // the end of its output list.
  if (output.index < 0 || output.index >= ic->num_outputs()) {
    status->status = InvalidArgument(
        "Output index ", output.index, " is out of range for node ",
        node->name(), " which has ", ic->num_outputs(), " outputs");
    return;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  int rank = -1;
  if (ic->RankKnown(shape)) {
    rank = ic->Rank(shape);
  }

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  if (num_dims == 0) {
    // Output shape is a scalar.
    return;
  }

  // Rank is greater than 0, so fill in the values, if known, and
  // -1 for unknown values.
  for (int i = 0; i < num_dims; ++i) {
    auto dim = ic->Dim(shape, i);
    tensorflow::int64 value = -1;
    if (ic->ValueKnown(dim)) {
      value = ic->Value(dim);
    }
    dims[i] = value;
  }
}
|
|
|
|
// TF_OperationDescription functions ------------------------------------------
|
|
|
|
extern "C" {
|
|
|
|
// Allocates a TF_OperationDescription for a new `op_type` op named
// `oper_name` in `graph`. The caller must already hold graph->mu.
static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
                                                      const char* op_type,
                                                      const char* oper_name)
    TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  auto* const desc = new TF_OperationDescription(graph, op_type, oper_name);
  return desc;
}
|
|
|
|
// Locks graph->mu and starts describing a new `op_type` operation named
// `oper_name`.
TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
                                         const char* oper_name) {
  mutex_lock graph_lock(graph->mu);
  TF_OperationDescription* const desc =
      TF_NewOperationLocked(graph, op_type, oper_name);
  return desc;
}
|
|
|
|
// Records the requested `device` on the operation being built.
void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
  NodeBuilder& builder = desc->node_builder;
  builder.Device(device);
}
|
|
|
|
// Appends output `input.index` of `input.oper` as the next data input.
void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
  Node* const producer = &input.oper->node;
  desc->node_builder.Input(producer, input.index);
}
|
|
|
|
// Appends the `num_inputs` outputs in `inputs` as a single list-valued input.
void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
                     int num_inputs) {
  std::vector<NodeBuilder::NodeOut> input_list;
  input_list.reserve(num_inputs);
  for (int k = 0; k < num_inputs; ++k) {
    const TF_Output& in = inputs[k];
    input_list.emplace_back(&in.oper->node, in.index);
  }
  desc->node_builder.Input(input_list);
}
|
|
|
|
// Adds `input` as a control dependency of the operation being built.
void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
  Node* const control_source = &input->node;
  desc->node_builder.ControlInput(control_source);
}
|
|
|
|
// Adds a colocation constraint, kColocationGroupPrefix followed by op's node
// name, to the operation being built.
void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
  const string group =
      StrCat(tensorflow::kColocationGroupPrefix, op->node.name());
  desc->colocation_constraints.emplace(group);
}
|
|
|
|
// Sets string attr `attr_name` to the `length` bytes at `value`. The bytes
// need not be NUL-terminated.
void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
                      const void* value, size_t length) {
  const char* const chars = static_cast<const char*>(value);
  desc->node_builder.Attr(attr_name, tensorflow::StringPiece(chars, length));
}
|
|
|
|
// Sets list(string) attr `attr_name` from `num_values` (pointer, length)
// pairs. The colocation attr (kColocationAttrName) is not set on the node
// builder. Instead it replaces the contents of desc->colocation_constraints.
void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
                          const void* const* values, const size_t* lengths,
                          int num_values) {
  const bool is_colocation =
      strcmp(attr_name, tensorflow::kColocationAttrName) == 0;
  if (!is_colocation) {
    std::vector<tensorflow::StringPiece> pieces;
    pieces.reserve(num_values);
    for (int k = 0; k < num_values; ++k) {
      pieces.emplace_back(static_cast<const char*>(values[k]), lengths[k]);
    }
    desc->node_builder.Attr(attr_name, pieces);
    return;
  }

  desc->colocation_constraints.clear();
  for (int k = 0; k < num_values; ++k) {
    desc->colocation_constraints.emplace(static_cast<const char*>(values[k]),
                                         lengths[k]);
  }
}
|
|
|
|
// Sets int attr `attr_name` to `value`.
void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
                   int64_t value) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  const auto as_tf_int = static_cast<tensorflow::int64>(value);
  desc->node_builder.Attr(attr_name, as_tf_int);
}
|
|
|
|
// Sets list(int) attr `attr_name` to the `num_values` entries of `values`.
// int64_t and tensorflow::int64 are reinterpreted in place, which the
// static_assert checks are the same size.
void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
                       const int64_t* values, int num_values) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  const auto* const tf_values =
      reinterpret_cast<const tensorflow::int64*>(values);
  desc->node_builder.Attr(
      attr_name, ArraySlice<const tensorflow::int64>(tf_values, num_values));
}
|
|
|
|
// Sets float attr `attr_name` to `value`.
void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
                     float value) {
  NodeBuilder& builder = desc->node_builder;
  builder.Attr(attr_name, value);
}
|
|
|
|
// Sets list(float) attr `attr_name` to the `num_values` entries of `values`.
void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
                         const float* values, int num_values) {
  const ArraySlice<const float> slice(values, num_values);
  desc->node_builder.Attr(attr_name, slice);
}
|
|
|
|
// Sets bool attr `attr_name`. Any nonzero `value` is treated as true.
void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
                    unsigned char value) {
  const bool flag = value != 0;
  desc->node_builder.Attr(attr_name, flag);
}
|
|
|
|
// Sets list(bool) attr `attr_name` from `num_values` unsigned chars. Nonzero
// values become true. The bytes are copied into a real bool array because
// unsigned char cannot be viewed as bool in place.
void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
                        const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> flags(new bool[num_values]);
  std::transform(values, values + num_values, flags.get(),
                 [](unsigned char c) { return c != 0; });
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const bool>(flags.get(), num_values));
}
|
|
|
|
// Sets type attr `attr_name` to `value`, converted to tensorflow::DataType.
void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
                    TF_DataType value) {
  const DataType dtype = static_cast<DataType>(value);
  desc->node_builder.Attr(attr_name, dtype);
}
|
|
|
|
// Sets list(type) attr `attr_name`. The TF_DataType array is reinterpreted in
// place as an array of tensorflow::DataType.
void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
                        const TF_DataType* values, int num_values) {
  const auto* const dtypes = reinterpret_cast<const DataType*>(values);
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const DataType>(dtypes, num_values));
}
|
|
|
|
// Sets attribute `attr_name` to an AttrValue whose `placeholder` field is
// `placeholder`.
void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
                           const char* placeholder) {
  tensorflow::AttrValue placeholder_attr;
  placeholder_attr.set_placeholder(placeholder);
  desc->node_builder.Attr(attr_name, placeholder_attr);
}
|
|
|
|
// Sets the `func` attribute `attr_name` to a NameAttrList naming the function
// given by the `length` bytes at `value` (not required to be NUL-terminated).
void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
                        const char* value, size_t length) {
  tensorflow::NameAttrList func;
  func.set_name(string(value, length));
  desc->node_builder.Attr(attr_name, func);
}
|
|
|
|
// Sets the `shape` attribute `attr_name` from `num_dims` dimension sizes.
// A negative `num_dims` sets a default-constructed PartialTensorShape
// (presumably unknown rank; the C API uses -1 for that case).
void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
                     const int64_t* dims, int num_dims) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  PartialTensorShape shape;
  if (num_dims < 0) {
    desc->node_builder.Attr(attr_name, shape);
    return;
  }
  const auto* dim_sizes = reinterpret_cast<const tensorflow::int64*>(dims);
  shape = PartialTensorShape(ArraySlice<tensorflow::int64>(dim_sizes, num_dims));
  desc->node_builder.Attr(attr_name, shape);
}
|
|
|
|
// Sets the list(shape) attribute `attr_name`. Shape `i` has `num_dims[i]`
// dimensions read from `dims[i]`; a negative `num_dims[i]` yields a
// default-constructed PartialTensorShape for that entry.
void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
                         const int64_t* const* dims, const int* num_dims,
                         int num_shapes) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  std::vector<PartialTensorShape> shapes;
  shapes.reserve(num_shapes);
  for (int i = 0; i < num_shapes; ++i) {
    const int rank = num_dims[i];
    if (rank >= 0) {
      const auto* dim_sizes =
          reinterpret_cast<const tensorflow::int64*>(dims[i]);
      shapes.emplace_back(ArraySlice<tensorflow::int64>(dim_sizes, rank));
    } else {
      shapes.emplace_back();
    }
  }
  desc->node_builder.Attr(attr_name, shapes);
}
|
|
|
|
// Sets the `shape` attribute `attr_name` from a serialized TensorShapeProto of
// `proto_len` bytes. Sets `status` to InvalidArgument if the length does not
// fit in an int or the bytes do not parse; OK otherwise.
void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
                                const char* attr_name, const void* proto,
                                size_t proto_len, TF_Status* status) {
  // ParseFromArray takes an int length while this API takes a size_t; refuse
  // lengths that would be truncated.
  if (proto_len > std::numeric_limits<int>::max()) {
    status->status = InvalidArgument(
        "proto_len (", proto_len,
        " bytes) is too large to be parsed by the protocol buffer library");
    return;
  }
  TensorShapeProto shape;
  if (!shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument("Unparseable TensorShapeProto");
    return;
  }
  desc->node_builder.Attr(attr_name, shape);
  status->status = Status::OK();
}
|
|
|
|
// Sets the list(shape) attribute `attr_name` from `num_shapes` serialized
// TensorShapeProtos (`protos[i]` of `proto_lens[i]` bytes). Stops at the first
// element that is too long or fails to parse, setting InvalidArgument and
// leaving the attribute unset.
void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
                                    const char* attr_name,
                                    const void* const* protos,
                                    const size_t* proto_lens, int num_shapes,
                                    TF_Status* status) {
  std::vector<TensorShapeProto> shapes(num_shapes);
  for (int i = 0; i < num_shapes; ++i) {
    const size_t len = proto_lens[i];
    // ParseFromArray takes an int length; refuse lengths that would truncate.
    if (len > std::numeric_limits<int>::max()) {
      status->status = InvalidArgument(
          "length of element ", i, " in the list (", len,
          " bytes) is too large to be parsed by the protocol buffer library");
      return;
    }
    if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(len))) {
      status->status =
          InvalidArgument("Unparseable TensorShapeProto at index ", i);
      return;
    }
  }
  desc->node_builder.Attr(attr_name, shapes);
  status->status = Status::OK();
}
|
|
|
|
// Sets the `tensor` attribute `attr_name` from `value`. The attribute is only
// set when TF_TensorToTensor succeeds; its status is reported in `status`.
void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
                      TF_Tensor* value, TF_Status* status) {
  Tensor converted;
  status->status = TF_TensorToTensor(value, &converted);
  if (!status->status.ok()) return;
  desc->node_builder.Attr(attr_name, converted);
}
|
|
|
|
// Sets the list(tensor) attribute `attr_name` from `num_values` C tensors.
// Conversion stops at the first failure, whose status is left in `status`,
// and the attribute is then not set.
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
                          TF_Tensor* const* values, int num_values,
                          TF_Status* status) {
  status->status = Status::OK();
  std::vector<Tensor> tensors;
  tensors.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    Tensor converted;
    status->status = TF_TensorToTensor(values[i], &converted);
    tensors.push_back(converted);
    if (!status->status.ok()) return;
  }
  desc->node_builder.Attr(attr_name, tensors);
}
|
|
|
|
// Sets attribute `attr_name` from a serialized tensorflow.AttrValue proto of
// `proto_len` bytes.
//
// The colocation attribute (tensorflow::kColocationAttrName) is special-cased:
// its value must be a list (or unset), and its string entries replace the
// description's collected colocation constraints rather than being passed to
// the node builder directly.
//
// Sets `status` to InvalidArgument if `proto_len` does not fit in an int, the
// bytes do not parse, or the colocation attribute holds a non-list value;
// otherwise OK.
void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
                          const void* proto, size_t proto_len,
                          TF_Status* status) {
  // ParseFromArray takes an int length; previously proto_len was narrowed
  // silently. Reject oversized inputs, as TF_SetAttrTensorShapeProto does.
  if (proto_len > static_cast<size_t>(std::numeric_limits<int>::max())) {
    status->status = InvalidArgument(
        "proto_len (", proto_len,
        " bytes) is too large to be parsed by the protocol buffer library");
    return;
  }
  tensorflow::AttrValue attr_value;
  if (!attr_value.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument("Unparseable AttrValue proto");
    return;
  }

  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
    if (attr_value.value_case() != tensorflow::AttrValue::kList &&
        attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
      status->status =
          InvalidArgument("Expected \"list\" field for \"",
                          tensorflow::kColocationAttrName, "\" attribute");
      return;
    }
    // Colocation constraints are accumulated on the description and applied
    // when the operation is finished.
    desc->colocation_constraints.clear();
    for (const string& location : attr_value.list().s()) {
      desc->colocation_constraints.insert(location);
    }
  } else {
    desc->node_builder.Attr(attr_name, std::move(attr_value));
  }

  status->status = Status::OK();
}
|
|
|
|
// Turns `desc` into a node of `desc->graph`. The caller must hold
// `desc->graph->mu`.
//
// Fails with InvalidArgument if the requested node name is already present in
// the graph's name map. Otherwise:
//   1. any collected colocation constraints are attached as the colocation
//      attribute,
//   2. the node builder is finalized into the graph (consuming the builder),
//   3. shape inference is run on the new node through the graph's refiner,
//   4. on success, the node is recorded in the name-to-node map.
// If a node was created but a later step failed, it is removed from the graph.
//
// `desc` is deleted on every path. Returns ToOperation(ret), where `ret` is
// nullptr on every failure path.
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                              TF_Status* status)
    TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '",
                                     desc->node_builder.node_name(), "'");
  } else {
    if (!desc->colocation_constraints.empty()) {
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                                 /*consume=*/true);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      // A node was inserted but a later step (e.g. shape inference) failed:
      // undo the insertion so the graph does not keep a half-built node.
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  // The description is consumed regardless of the outcome.
  delete desc;

  return ToOperation(ret);
}
|
|
|
|
// Locks the description's graph and finishes the operation under that lock.
// `desc` is consumed (deleted) whether or not the call succeeds.
TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
                                 TF_Status* status) {
  mutex_lock graph_lock(desc->graph->mu);
  TF_Operation* finished = TF_FinishOperationLocked(desc, status);
  return finished;
}
|
|
|
|
// TF_Operation functions
|
|
// ----------------------------------------------------------
|
|
|
|
// Returns the node's name; the pointer refers to storage owned by `oper`.
const char* TF_OperationName(TF_Operation* oper) {
  const auto& node = oper->node;
  return node.name().c_str();
}
|
|
|
|
// Returns the node's op type string; the pointer refers to storage owned by
// `oper`.
const char* TF_OperationOpType(TF_Operation* oper) {
  const auto& node = oper->node;
  return node.type_string().c_str();
}
|
|
|
|
// Returns the device requested for the node (not necessarily the assigned
// one); the pointer refers to storage owned by `oper`.
const char* TF_OperationDevice(TF_Operation* oper) {
  const auto& node = oper->node;
  return node.requested_device().c_str();
}
|
|
|
|
// Returns the number of outputs of the operation's node.
int TF_OperationNumOutputs(TF_Operation* oper) {
  const auto& node = oper->node;
  return node.num_outputs();
}
|
|
|
|
// Returns the data type of output `oper_out.index` of `oper_out.oper`.
TF_DataType TF_OperationOutputType(TF_Output oper_out) {
  const auto dtype = oper_out.oper->node.output_type(oper_out.index);
  return static_cast<TF_DataType>(dtype);
}
|
|
|
|
// Returns how many output tensors the OpDef output argument `arg_name` expands
// to for this node. Returns -1 and sets `status` on error (including an
// unknown argument name).
int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
                                 TF_Status* status) {
  NameRangeMap output_ranges;
  status->status = NameRangesForNode(oper->node, oper->node.op_def(), nullptr,
                                     &output_ranges);
  if (!status->status.ok()) return -1;
  const auto found = output_ranges.find(arg_name);
  if (found != output_ranges.end()) {
    return found->second.second - found->second.first;
  }
  status->status = InvalidArgument("Output arg '", arg_name, "' not found");
  return -1;
}
|
|
|
|
// Returns the number of (data) inputs of the operation's node.
int TF_OperationNumInputs(TF_Operation* oper) {
  const auto& node = oper->node;
  return node.num_inputs();
}
|
|
|
|
// Returns the data type expected at input `oper_in.index` of `oper_in.oper`.
TF_DataType TF_OperationInputType(TF_Input oper_in) {
  const auto dtype = oper_in.oper->node.input_type(oper_in.index);
  return static_cast<TF_DataType>(dtype);
}
|
|
|
|
// Returns how many input tensors the OpDef input argument `arg_name` expands
// to for this node. Returns -1 and sets `status` on error (including an
// unknown argument name).
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
                                TF_Status* status) {
  NameRangeMap input_ranges;
  status->status = NameRangesForNode(oper->node, oper->node.op_def(),
                                     &input_ranges, nullptr);
  if (!status->status.ok()) return -1;
  const auto found = input_ranges.find(arg_name);
  if (found != input_ranges.end()) {
    return found->second.second - found->second.first;
  }
  status->status = InvalidArgument("Input arg '", arg_name, "' not found");
  return -1;
}
|
|
|
|
// Returns the producer (operation, output index) feeding input
// `oper_in.index`, or {nullptr, -1} if that input edge cannot be found.
TF_Output TF_OperationInput(TF_Input oper_in) {
  const tensorflow::Edge* edge = nullptr;
  if (oper_in.oper->node.input_edge(oper_in.index, &edge).ok()) {
    return {ToOperation(edge->src()), edge->src_output()};
  }
  return {nullptr, -1};
}
|
|
|
|
// Writes the producer of each data input into `inputs[input_slot]`, skipping
// control edges (negative slots) and slots >= `max_inputs`.
void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
                           int max_inputs) {
  for (const auto* edge : oper->node.in_edges()) {
    const int slot = edge->dst_input();
    if (slot < 0 || slot >= max_inputs) continue;
    inputs[slot] = {ToOperation(edge->src()), edge->src_output()};
  }
}
|
|
|
|
// Counts the out-edges of `oper_out.oper` that originate at output
// `oper_out.index`.
int TF_OperationOutputNumConsumers(TF_Output oper_out) {
  int consumers = 0;
  for (const auto* edge : oper_out.oper->node.out_edges()) {
    if (edge->src_output() != oper_out.index) continue;
    ++consumers;
  }
  return consumers;
}
|
|
|
|
// Writes up to `max_consumers` consumers of output `oper_out` into
// `consumers` and returns the total number of consumers, which may exceed
// `max_consumers`.
int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
                                int max_consumers) {
  int found = 0;
  for (const auto* edge : oper_out.oper->node.out_edges()) {
    if (edge->src_output() != oper_out.index) continue;
    if (found < max_consumers) {
      consumers[found] = {ToOperation(edge->dst()), edge->dst_input()};
    }
    ++found;
  }
  return found;
}
|
|
|
|
// Counts incoming control edges, ignoring those from the graph's source node.
int TF_OperationNumControlInputs(TF_Operation* oper) {
  int control_inputs = 0;
  for (const auto* edge : oper->node.in_edges()) {
    if (!edge->IsControlEdge() || edge->src()->IsSource()) continue;
    ++control_inputs;
  }
  return control_inputs;
}
|
|
|
|
// Writes up to `max_control_inputs` control-input operations (excluding the
// graph's source node) into `control_inputs`. Returns the total count, which
// may exceed `max_control_inputs`.
int TF_OperationGetControlInputs(TF_Operation* oper,
                                 TF_Operation** control_inputs,
                                 int max_control_inputs) {
  int found = 0;
  for (const auto* edge : oper->node.in_edges()) {
    if (!edge->IsControlEdge() || edge->src()->IsSource()) continue;
    if (found < max_control_inputs) {
      control_inputs[found] = ToOperation(edge->src());
    }
    ++found;
  }
  return found;
}
|
|
|
|
// Counts outgoing control edges, ignoring those to the graph's sink node.
int TF_OperationNumControlOutputs(TF_Operation* oper) {
  int control_outputs = 0;
  for (const auto* edge : oper->node.out_edges()) {
    if (!edge->IsControlEdge() || edge->dst()->IsSink()) continue;
    ++control_outputs;
  }
  return control_outputs;
}
|
|
|
|
// Writes up to `max_control_outputs` control-output operations (excluding the
// graph's sink node) into `control_outputs`. Returns the total count, which
// may exceed `max_control_outputs`.
int TF_OperationGetControlOutputs(TF_Operation* oper,
                                  TF_Operation** control_outputs,
                                  int max_control_outputs) {
  int found = 0;
  for (const auto* edge : oper->node.out_edges()) {
    if (!edge->IsControlEdge() || edge->dst()->IsSink()) continue;
    if (found < max_control_outputs) {
      control_outputs[found] = ToOperation(edge->dst());
    }
    ++found;
  }
  return found;
}
|
|
|
|
// Describes attribute `attr_name` of `oper`. The result records:
//   * whether the attribute is a list,
//   * its element type,
//   * its list length (-1 for non-lists),
//   * a type-dependent total size:
//       - string: byte length,
//       - shape: rank,
//       - list of strings or shapes: summed bytes or ranks,
//       - otherwise: -1.
//
// For an empty list, the element type cannot be read from the value itself,
// so it is recovered from the attribute's declaration in the op's OpDef.
//
// On error `status` is set and the returned struct is zero-initialized.
TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
                                            const char* attr_name,
                                            TF_Status* status) {
  // Value-initialize so callers never see indeterminate fields on error paths.
  TF_AttrMetadata metadata{};
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return metadata;
  switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
  case tensorflow::AttrValue::kK:             \
    metadata.is_list = 0;                     \
    metadata.list_size = -1;                  \
    metadata.type = attr_type;                \
    metadata.total_size = size_expr;          \
    break;

    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
    SINGLE_CASE(kI, TF_ATTR_INT, -1);
    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE

    case tensorflow::AttrValue::kList:
      metadata.is_list = 1;
      metadata.list_size = 0;
      metadata.total_size = -1;
      // LIST_CASE: if the list's `field` is non-empty, record its type and
      // length, run any extra statements, and leave the switch.
#define LIST_CASE(field, attr_type, ...)              \
  if (attr->list().field##_size() > 0) {              \
    metadata.type = attr_type;                        \
    metadata.list_size = attr->list().field##_size(); \
    __VA_ARGS__;                                      \
    break;                                            \
  }

      LIST_CASE(
          s, TF_ATTR_STRING, metadata.total_size = 0;
          for (int i = 0; i < attr->list().s_size();
               ++i) { metadata.total_size += attr->list().s(i).size(); });
      LIST_CASE(i, TF_ATTR_INT);
      LIST_CASE(f, TF_ATTR_FLOAT);
      LIST_CASE(b, TF_ATTR_BOOL);
      LIST_CASE(type, TF_ATTR_TYPE);
      LIST_CASE(
          shape, TF_ATTR_SHAPE, metadata.total_size = 0;
          for (int i = 0; i < attr->list().shape_size(); ++i) {
            const auto& s = attr->list().shape(i);
            metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
          });
      LIST_CASE(tensor, TF_ATTR_TENSOR);
      // Function lists live in the `func` field. This previously tested
      // `tensor` a second time, so a non-empty list(func) was reported with
      // list_size == 0.
      LIST_CASE(func, TF_ATTR_FUNC);
#undef LIST_CASE
      // All lists empty, determine the type from the OpDef.
      if (metadata.list_size == 0) {
        for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
          const auto& a = oper->node.op_def().attr(i);
          if (a.name() != attr_name) continue;
          const string& typestr = a.type();
          if (typestr == "list(string)") {
            metadata.type = TF_ATTR_STRING;
          } else if (typestr == "list(int)") {
            metadata.type = TF_ATTR_INT;
          } else if (typestr == "list(float)") {
            metadata.type = TF_ATTR_FLOAT;
          } else if (typestr == "list(bool)") {
            metadata.type = TF_ATTR_BOOL;
          } else if (typestr == "list(type)") {
            metadata.type = TF_ATTR_TYPE;
          } else if (typestr == "list(shape)") {
            metadata.type = TF_ATTR_SHAPE;
          } else if (typestr == "list(tensor)") {
            metadata.type = TF_ATTR_TENSOR;
          } else if (typestr == "list(func)") {
            metadata.type = TF_ATTR_FUNC;
          } else {
            status->status = InvalidArgument(
                "Attribute '", attr_name,
                "' has an empty value of an unrecognized type '", typestr, "'");
            return metadata;
          }
        }
      }
      break;

    case tensorflow::AttrValue::kPlaceholder:
      metadata.is_list = 0;
      metadata.list_size = -1;
      metadata.type = TF_ATTR_PLACEHOLDER;
      metadata.total_size = -1;
      break;

    case tensorflow::AttrValue::kFunc:
      metadata.is_list = 0;
      metadata.list_size = -1;
      metadata.type = TF_ATTR_FUNC;
      metadata.total_size = -1;
      break;

    case tensorflow::AttrValue::VALUE_NOT_SET:
      status->status =
          InvalidArgument("Attribute '", attr_name, "' has no value set");
      break;
  }
  return metadata;
}
|
|
|
|
// Copies up to `max_length` bytes of the string attribute `attr_name` into
// `value`. No NUL terminator is written. Sets InvalidArgument if the attribute
// is not a string.
void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
                               void* value, size_t max_length,
                               TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kS) {
    status->status =
        InvalidArgument("Attribute '", attr_name, "' is not a string");
    return;
  }
  if (max_length == 0) return;
  const auto& str = attr->s();
  const size_t bytes_to_copy = std::min<size_t>(str.length(), max_length);
  std::memcpy(value, str.data(), bytes_to_copy);
}
|
|
|
|
// Copies up to `max_values` strings of the list(string) attribute `attr_name`
// back-to-back into `storage` (`storage_size` bytes).
//
// For each string i:
//   * `values[i]` points at its bytes within `storage`,
//   * `lengths[i]` holds its length.
// Strings are not NUL-terminated.
//
// Sets InvalidArgument if the attribute is not a list or `storage` is too
// small. In the second case, the entries for the string that did not fit are
// already filled in.
void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
                                   void** values, size_t* lengths,
                                   int max_values, void* storage,
                                   size_t storage_size, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
    return;
  }
  const auto len = std::min(max_values, attr->list().s_size());
  char* p = static_cast<char*>(storage);
  // Track the remaining capacity as a byte count. The old check compared
  // `p + s.size()` against the end of `storage`, and forming a pointer past
  // one-past-the-end of the buffer is undefined behavior.
  size_t remaining = storage_size;
  for (int i = 0; i < len; ++i) {
    const string& s = attr->list().s(i);
    values[i] = p;
    lengths[i] = s.size();
    if (s.size() > remaining) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    // Skip empty strings so a null `storage` is never handed to memcpy.
    if (!s.empty()) memcpy(values[i], s.data(), s.size());
    p += s.size();
    remaining -= s.size();
  }
}
|
|
|
|
// DEFINE_GETATTR(func, c_type, cpp_type, list_field) defines two C API
// getters:
//
//  * `func` reads a scalar attribute with tensorflow::GetNodeAttr into a
//    `cpp_type` and stores it, static_cast to `c_type`, in `*value`.
//    `*value` is written only on success.
//  * `func##List` fetches the raw AttrValue and fails with InvalidArgument if
//    it is not a list. Otherwise it copies at most `max_values` elements of
//    the list's `list_field` into `values`, static_cast to `c_type`.
//    It does not check which list field is populated: if `list_field` is
//    empty, nothing is copied and `status` stays OK.
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
            TF_Status* status) {                                             \
    cpp_type v;                                                              \
    status->status =                                                         \
        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);         \
    if (!status->status.ok()) return;                                        \
    *value = static_cast<c_type>(v);                                         \
  }                                                                          \
  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                  int max_values, TF_Status* status) {                       \
    const auto* attr = GetAttrValue(oper, attr_name, status);                \
    if (!status->status.ok()) return;                                        \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
      status->status =                                                       \
          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
      return;                                                                \
    }                                                                        \
    const auto len = std::min(max_values, attr->list().list_field##_size()); \
    for (int i = 0; i < len; ++i) {                                          \
      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
    }                                                                        \
  }
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
#undef DEFINE_GETATTR
|
|
|
|
// Writes up to `num_dims` dimension sizes of the shape attribute `attr_name`
// into `value`. For a shape of unknown rank, dims() is -1, so nothing is
// written.
void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
                              int64_t* value, int num_dims, TF_Status* status) {
  PartialTensorShape shape;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
  if (!status->status.ok()) return;
  const int dims_to_copy = std::min(shape.dims(), num_dims);
  for (int d = 0; d < dims_to_copy; ++d) value[d] = shape.dim_size(d);
}
|
|
|
|
// Reads up to `num_shapes` shapes of the list(shape) attribute `attr_name`.
// For each shape i:
//   * `num_dims[i]` receives its rank (-1 for unknown rank),
//   * `dims[i]` points into `storage`, where its dimension sizes are packed
//     contiguously.
// Sets InvalidArgument if `storage_size` entries are not enough.
void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
                                  int64_t** dims, int* num_dims, int num_shapes,
                                  int64_t* storage, int storage_size,
                                  TF_Status* status) {
  std::vector<PartialTensorShape> shapes;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
  if (!status->status.ok()) return;
  const int count = std::min(static_cast<int>(shapes.size()), num_shapes);
  int64_t* cursor = storage;
  int remaining = storage_size;
  for (int i = 0; i < count; ++i) {
    const PartialTensorShape& shape = shapes[i];
    const int64_t rank = shape.dims();  // -1 for an unknown rank.
    num_dims[i] = rank;
    dims[i] = cursor;
    if (rank < 0) continue;
    if (rank > remaining) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of shapes");
      return;
    }
    remaining -= rank;
    for (int d = 0; d < rank; ++d) *cursor++ = shape.dim_size(d);
  }
}
|
|
|
|
// Serializes the shape attribute `attr_name` as a TensorShapeProto into
// `value`. Sets InvalidArgument if the attribute is not a shape.
void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
                                         const char* attr_name,
                                         TF_Buffer* value, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() == tensorflow::AttrValue::kShape) {
    status->status = MessageToBuffer(attr->shape(), value);
    return;
  }
  status->status =
      InvalidArgument("Value for '", attr_name, "' is not a shape.");
}
|
|
|
|
// Serializes up to `max_values` shapes of the list(shape) attribute
// `attr_name` into newly allocated buffers stored in `values`. Ownership of
// the buffers passes to the caller.
//
// On a serialization failure, every buffer allocated by this call is deleted
// and its slot in `values` is reset to nullptr, so the caller is never left
// holding dangling pointers.
void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
                                             const char* attr_name,
                                             TF_Buffer** values, int max_values,
                                             TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
    return;
  }
  const auto len = std::min(max_values, attr->list().shape_size());
  for (int i = 0; i < len; ++i) {
    values[i] = TF_NewBuffer();
    status->status = MessageToBuffer(attr->list().shape(i), values[i]);
    if (!status->status.ok()) {
      // Delete everything allocated so far, the operation has failed.
      for (int j = 0; j <= i; ++j) {
        TF_DeleteBuffer(values[j]);
        values[j] = nullptr;
      }
      return;
    }
  }
}
|
|
|
|
// Converts the tensor attribute `attr_name` into a newly allocated TF_Tensor
// stored in `*value`. `*value` is nullptr if the attribute cannot be read.
void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
                               TF_Tensor** value, TF_Status* status) {
  *value = nullptr;
  Tensor attr_tensor;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &attr_tensor);
  if (status->status.ok()) {
    *value = TF_TensorFromTensor(attr_tensor, &status->status);
  }
}
|
|
|
|
// Converts up to `max_values` tensors of the list(tensor) attribute
// `attr_name` into newly allocated TF_Tensors stored in `values`. Ownership of
// the tensors passes to the caller.
//
// Conversion stops at the first failure, so a later success cannot overwrite
// the error in `status`. The tensors created so far are released and their
// slots reset to nullptr, so the caller receives either all of them or none.
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
                                   TF_Tensor** values, int max_values,
                                   TF_Status* status) {
  std::vector<Tensor> ts;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
  if (!status->status.ok()) return;
  const auto len = std::min(max_values, static_cast<int>(ts.size()));
  for (int i = 0; i < len; ++i) {
    values[i] = TF_TensorFromTensor(ts[i], &status->status);
    if (!status->status.ok()) {
      for (int j = 0; j <= i; ++j) {
        if (values[j] != nullptr) TF_DeleteTensor(values[j]);
        values[j] = nullptr;
      }
      return;
    }
  }
}
|
|
|
|
// Serializes the raw AttrValue proto of attribute `attr_name` into
// `output_attr_value`.
void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
                                   TF_Buffer* output_attr_value,
                                   TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (status->status.ok()) {
    status->status = MessageToBuffer(*attr, output_attr_value);
  }
}
|
|
|
|
// Serializes the operation's NodeDef into `output_node_def`.
void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
                           TF_Status* status) {
  const auto& node_def = oper->node.def();
  status->status = MessageToBuffer(node_def, output_node_def);
}
|
|
|
|
// TF_Graph functions ---------------------------------------------------------
|
|
|
|
// Constructs an empty graph over the global op registry. The shape refiner is
// created from the graph's producer version and op registry.
// NOTE(review): `refiner` is initialized from `graph`, which relies on `graph`
// being declared before `refiner` in TF_Graph — confirm in c_api_internal.h.
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}
|
|
|
|
// Allocates a new, empty graph; release it with TF_DeleteGraph.
TF_Graph* TF_NewGraph() {
  TF_Graph* graph = new TF_Graph;
  return graph;
}
|
|
|
|
// Requests deletion of `g`. The graph is freed immediately only if no
// sessions are registered on it; otherwise it is only marked with
// `delete_requested`. Null is ignored.
void TF_DeleteGraph(TF_Graph* g) {
  if (!g) return;
  g->mu.lock();
  g->delete_requested = true;
  const bool sessions_remaining = !g->sessions.empty();
  g->mu.unlock();
  if (!sessions_remaining) delete g;
}
|
|
|
|
// Looks up an operation by node name under the graph lock; nullptr if absent.
TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
  mutex_lock l(graph->mu);
  const auto it = graph->name_map.find(oper_name);
  return it != graph->name_map.end() ? ToOperation(it->second) : nullptr;
}
|
|
|
|
// Iterates over the graph's operations by node id. Start with `*pos == 0`;
// each call advances `*pos` and returns the next live operation, or nullptr
// when there are no more.
TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
  // From 0, skip the two sentinel nodes every graph has (source & sink);
  // otherwise step past the node returned last time.
  *pos += (*pos == 0) ? 2 : 1;

  mutex_lock l(graph->mu);
  for (; *pos < static_cast<size_t>(graph->graph.num_node_ids()); ++*pos) {
    // FindNodeId() returns nullptr for ids whose node has been removed (e.g.
    // TF_FinishOperationLocked removes a node whose shape inference failed).
    Node* node = graph->graph.FindNodeId(*pos);
    if (node != nullptr) return ToOperation(node);
  }

  // No more nodes.
  return nullptr;
}
|
|
|
|
// Snapshots the graph as a GraphDef under the lock, then serializes the
// snapshot into `output_graph_def` without holding the lock.
void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
                        TF_Status* status) {
  GraphDef graph_def;
  {
    mutex_lock lock(graph->mu);
    graph->graph.ToGraphDef(&graph_def);
  }
  status->status = MessageToBuffer(graph_def, output_graph_def);
}
|
|
|
|
// Looks up the OpDef for `op_name` in the graph's op registry and serializes
// it into `output_op_def`.
void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
                      TF_Buffer* output_op_def, TF_Status* status) {
  const OpDef* op_def = nullptr;
  {
    mutex_lock lock(graph->mu);
    status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
  }
  if (!status->status.ok()) return;
  status->status = MessageToBuffer(*op_def, output_op_def);
}
|
|
|
|
// Copies the graph's VersionDef under the lock and serializes the copy into
// `output_version_def`.
void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
                      TF_Status* status) {
  VersionDef version_def;
  {
    mutex_lock lock(graph->mu);
    version_def = graph->graph.versions();
  }
  status->status = MessageToBuffer(version_def, output_version_def);
}
|
|
|
|
// Allocates default import options; release with
// TF_DeleteImportGraphDefOptions.
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
  auto* options = new TF_ImportGraphDefOptions;
  return options;
}
|
|
// Frees options created by TF_NewImportGraphDefOptions; null is ignored.
void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
  if (opts == nullptr) return;
  delete opts;
}
|
|
// Stores `prefix` (copied) as the import's node-name prefix.
void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
                                       const char* prefix) {
  auto& import_opts = opts->opts;
  import_opts.prefix = prefix;
}
|
|
// Stores `device` (copied) as the import's default device.
void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
                                              const char* device) {
  auto& import_opts = opts->opts;
  import_opts.default_device = device;
}
|
|
|
|
// Sets ImportGraphDefOptions::uniquify_names; any nonzero byte means true.
void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
                                              unsigned char uniquify_names) {
  opts->opts.uniquify_names = (uniquify_names != 0);
}
|
|
|
|
// Sets ImportGraphDefOptions::uniquify_prefix; any nonzero byte means true.
void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
                                               unsigned char uniquify_prefix) {
  opts->opts.uniquify_prefix = (uniquify_prefix != 0);
}
|
|
|
|
// Maps input tensor `src_name:src_index` of the imported GraphDef to `dst`.
// `src_name` is copied into `opts->tensor_id_data` so the TensorId key refers
// to storage owned by `opts` rather than the caller's string.
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                             const char* src_name,
                                             int src_index, TF_Output dst) {
  auto& owned_names = opts->tensor_id_data;
  owned_names.push_back(src_name);
  // dst's name need not be copied: `dst` must outlive the ImportGraphDef call.
  const TensorId key(owned_names.back(), src_index);
  opts->opts.input_map[key] = ToTensorId(dst);
}
|
|
|
|
// Redirects control dependencies on `src_name` in the imported GraphDef to
// `dst`, using the control slot for both sides of the mapping.
void TF_ImportGraphDefOptionsRemapControlDependency(
    TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
  const int control_slot = tensorflow::Graph::kControlSlot;
  opts->opts.input_map[TensorId(src_name, control_slot)] =
      TensorId(dst->node.name(), control_slot);
}
|
|
|
|
// Appends `oper`'s node name to the import's control dependencies.
extern void TF_ImportGraphDefOptionsAddControlDependency(
    TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
  const auto& oper_name = oper->node.name();
  opts->opts.control_dependencies.push_back(oper_name);
}
|
|
|
|
// Requests `oper_name:index` as a return tensor. The name is copied into
// `opts->tensor_id_data`, which the stored tensor id refers to.
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
                                             const char* oper_name, int index) {
  auto& owned_names = opts->tensor_id_data;
  owned_names.push_back(oper_name);
  opts->opts.return_tensors.emplace_back(owned_names.back(), index);
}
|
|
|
|
int TF_ImportGraphDefOptionsNumReturnOutputs(
|
|
const TF_ImportGraphDefOptions* opts) {
|
|
return opts->opts.return_tensors.size();
|
|
}
|
|
|
|
// Requests the node named `oper_name` as a return operation.
void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
                                                const char* oper_name) {
  opts->opts.return_nodes.emplace_back(oper_name);
}
|
|
|
|
int TF_ImportGraphDefOptionsNumReturnOperations(
|
|
const TF_ImportGraphDefOptions* opts) {
|
|
return opts->opts.return_nodes.size();
|
|
}
|
|
|
|
// Exposes the imported return tensors. `*outputs` points into `results` and
// stays valid only as long as `results` does.
void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
                                           int* num_outputs,
                                           TF_Output** outputs) {
  auto& tensors = results->return_tensors;
  *num_outputs = static_cast<int>(tensors.size());
  *outputs = tensors.data();
}
|
|
|
|
// Exposes the imported return operations. `*opers` points into `results` and
// stays valid only as long as `results` does.
void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
                                              int* num_opers,
                                              TF_Operation*** opers) {
  auto& nodes = results->return_nodes;
  *num_opers = static_cast<int>(nodes.size());
  *opers = nodes.data();
}
|
|
|
|
// Exposes the input-map keys that the import reported as missing/unused, as
// parallel arrays of names and indexes. The arrays point into `results`.
void TF_ImportGraphDefResultsMissingUnusedInputMappings(
    TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
    const char*** src_names, int** src_indexes) {
  auto& names = results->missing_unused_key_names;
  auto& indexes = results->missing_unused_key_indexes;
  *num_missing_unused_input_mappings = static_cast<int>(names.size());
  *src_names = names.data();
  *src_indexes = indexes.data();
}
|
|
|
|
// Frees import results; null is ignored.
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
  if (results == nullptr) return;
  delete results;
}
|
|
|
|
// Imports `def` into `graph` according to `opts` and fills `tf_results` with
// the C-API view of the import results. The caller must hold `graph->mu`, and
// `tf_results` is expected to be empty (DCHECKed below).
//
// If tensorflow::ImportGraphDef fails, `status` is set and the function
// returns before touching `graph->name_map` or `tf_results`.
static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
                                      const TF_ImportGraphDefOptions* opts,
                                      TF_ImportGraphDefResults* tf_results,
                                      TF_Status* status)
    TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  // Node-id watermark before the import. The name_map loop below assumes the
  // imported nodes receive ids at or above this value.
  const int last_node_id = graph->graph.num_node_ids();
  tensorflow::ImportGraphDefResults results;
  status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
                                              &graph->refiner, &results);
  if (!status->status.ok()) return;

  // Add new nodes to name_map
  for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
    auto* node = graph->graph.FindNodeId(i);
    if (node != nullptr) graph->name_map[node->name()] = node;
  }

  // Populate return_tensors
  DCHECK(tf_results->return_tensors.empty());
  tf_results->return_tensors.resize(results.return_tensors.size());
  for (int i = 0; i < results.return_tensors.size(); ++i) {
    tf_results->return_tensors[i].oper =
        ToOperation(results.return_tensors[i].first);
    tf_results->return_tensors[i].index = results.return_tensors[i].second;
  }

  // Populate return_nodes
  DCHECK(tf_results->return_nodes.empty());
  tf_results->return_nodes.resize(results.return_nodes.size());
  for (int i = 0; i < results.return_nodes.size(); ++i) {
    tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
  }

  // Populate missing unused map keys
  DCHECK(tf_results->missing_unused_key_names.empty());
  DCHECK(tf_results->missing_unused_key_indexes.empty());
  DCHECK(tf_results->missing_unused_key_names_data.empty());

  size_t size = results.missing_unused_input_map_keys.size();
  tf_results->missing_unused_key_names.resize(size);
  tf_results->missing_unused_key_indexes.resize(size);

  for (int i = 0; i < size; ++i) {
    TensorId id = results.missing_unused_input_map_keys[i];
    // Copy the name into owned storage and expose a C string pointing at it.
    // NOTE(review): this assumes missing_unused_key_names_data does not move
    // existing elements on emplace_back (e.g. a std::list); otherwise the
    // c_str() pointers saved earlier would dangle — confirm in
    // c_api_internal.h.
    tf_results->missing_unused_key_names_data.emplace_back(id.first);
    tf_results->missing_unused_key_names[i] =
        tf_results->missing_unused_key_names_data.back().c_str();
    tf_results->missing_unused_key_indexes[i] = id.second;
  }
}
|
|
|
|
// Parses `graph_def` and imports it into `graph`. Returns newly allocated
// results (owned by the caller; free with TF_DeleteImportGraphDefResults), or
// nullptr with `status` set if parsing or importing fails.
TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Status* status) {
  GraphDef def;
  const bool parsed = tensorflow::ParseProtoUnlimited(&def, graph_def->data,
                                                      graph_def->length);
  if (!parsed) {
    status->status = InvalidArgument("Invalid GraphDef");
    return nullptr;
  }
  // The results are owned here until the import succeeds and ownership is
  // handed to the caller.
  std::unique_ptr<TF_ImportGraphDefResults> results(
      new TF_ImportGraphDefResults());
  mutex_lock l(graph->mu);
  GraphImportGraphDefLocked(graph, def, options, results.get(), status);
  if (!status->status.ok()) return nullptr;
  return results.release();
}
|
|
|
|
// Imports `graph_def` into `graph` and copies the requested return outputs
// into the caller-allocated `return_outputs` array of `num_return_outputs`
// entries. The count must equal the number of return tensors requested in
// `options`.
//
// Sets InvalidArgument on a count mismatch, a missing output array, or an
// unparseable GraphDef. Import errors are reported in `status`, and in that
// case `return_outputs` is left untouched.
void TF_GraphImportGraphDefWithReturnOutputs(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
    int num_return_outputs, TF_Status* status) {
  if (num_return_outputs != options->opts.return_tensors.size()) {
    status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
                                     options->opts.return_tensors.size(),
                                     ", got ", num_return_outputs);
    return;
  }
  if (num_return_outputs > 0 && return_outputs == nullptr) {
    status->status = InvalidArgument(
        "'return_outputs' must be preallocated to length ", num_return_outputs);
    return;
  }
  GraphDef def;
  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  TF_ImportGraphDefResults results;
  mutex_lock l(graph->mu);
  GraphImportGraphDefLocked(graph, def, options, &results, status);
  // On failure GraphImportGraphDefLocked returns before populating `results`,
  // so return_tensors is empty. Copying num_return_outputs entries out of it
  // (as was previously done unconditionally) reads past the end of the
  // vector.
  if (!status->status.ok()) return;
  DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
  memcpy(return_outputs, results.return_tensors.data(),
         num_return_outputs * sizeof(TF_Output));
}
|
|
|
|
// Imports `graph_def` into `graph` and discards the detailed results; any
// failure is reported only through `status`.
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
                            const TF_ImportGraphDefOptions* options,
                            TF_Status* status) {
  TF_DeleteImportGraphDefResults(
      TF_GraphImportGraphDefWithResults(graph, graph_def, options, status));
}
|
|
|
|
// While loop functions -------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
|
|
// Creates a placeholder representing an input to the cond or body graph.
|
|
// TODO(skyewm): remove these from final graph
|
|
bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
|
|
TF_Output* input, TF_Status* status) {
|
|
TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
|
|
TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
|
|
// TODO(skyewm): set placeholder shape
|
|
TF_Operation* oper = TF_FinishOperation(desc, status);
|
|
if (!status->status.ok()) return false;
|
|
*input = {oper, 0};
|
|
return true;
|
|
}
|
|
|
|
// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
|
|
// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
|
|
// will be prepended to copied node names. `control_deps` are nodes in
|
|
// `dst_graph` that the copied `src_graph` nodes will have control dependencies
|
|
// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
|
|
// in `dst_graph` will be returned. `return_nodes` must be non-null.
|
|
Status CopyGraph(Graph* src_graph, Graph* dst_graph,
|
|
tensorflow::ShapeRefiner* dst_refiner,
|
|
const TF_Output* src_inputs,
|
|
const std::vector<tensorflow::Output>& dst_inputs,
|
|
const string& prefix,
|
|
const std::vector<tensorflow::Operation>& control_deps,
|
|
const TF_Output* nodes_to_return, int nreturn_nodes,
|
|
std::vector<tensorflow::Output>* return_nodes) {
|
|
DCHECK(return_nodes != nullptr);
|
|
GraphDef gdef;
|
|
src_graph->ToGraphDef(&gdef);
|
|
|
|
tensorflow::ImportGraphDefOptions opts;
|
|
opts.prefix = prefix;
|
|
|
|
for (int i = 0; i < dst_inputs.size(); ++i) {
|
|
opts.input_map[ToTensorId(src_inputs[i])] =
|
|
TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
|
|
}
|
|
opts.skip_mapped_nodes = true;
|
|
|
|
for (const tensorflow::Operation& op : control_deps) {
|
|
opts.control_dependencies.push_back(op.node()->name());
|
|
}
|
|
|
|
for (int i = 0; i < nreturn_nodes; ++i) {
|
|
opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
|
|
}
|
|
|
|
// TODO(skyewm): change to OutputTensor
|
|
tensorflow::ImportGraphDefResults results;
|
|
TF_RETURN_IF_ERROR(
|
|
ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
|
|
|
|
for (const auto& pair : results.return_tensors) {
|
|
return_nodes->emplace_back(pair.first, pair.second);
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
|
|
if (params.cond_graph == nullptr || params.body_graph == nullptr ||
|
|
params.cond_graph->parent == nullptr ||
|
|
params.cond_graph->parent != params.body_graph->parent ||
|
|
params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
|
|
params.ninputs <= 0 || params.cond_inputs == nullptr ||
|
|
params.body_inputs == nullptr || params.body_outputs == nullptr) {
|
|
s->status = InvalidArgument(
|
|
"TF_WhileParams must be created by successful TF_NewWhile() call");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
|
|
if (params.cond_output.oper == nullptr) {
|
|
s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
|
|
return false;
|
|
}
|
|
for (int i = 0; i < params.ninputs; ++i) {
|
|
if (params.body_outputs[i].oper == nullptr) {
|
|
s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
|
|
"field isn't set");
|
|
return false;
|
|
}
|
|
}
|
|
if (params.name == nullptr) {
|
|
s->status = InvalidArgument("TF_WhileParams `name` field is null");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
|
|
void FreeWhileResources(const TF_WhileParams* params) {
|
|
TF_DeleteGraph(params->cond_graph);
|
|
TF_DeleteGraph(params->body_graph);
|
|
delete[] params->cond_inputs;
|
|
delete[] params->body_inputs;
|
|
delete[] params->body_outputs;
|
|
}
|
|
|
|
// Returns a TF_WhileParams with zero inputs, every pointer null and a null
// `cond_output`; TF_NewWhile() returns this on failure.
TF_WhileParams EmptyWhileParams() {
  const TF_Output no_output = {nullptr, 0};
  return {/*ninputs=*/0,
          /*cond_graph=*/nullptr,
          /*cond_inputs=*/nullptr,
          /*cond_output=*/no_output,
          /*body_graph=*/nullptr,
          /*body_inputs=*/nullptr,
          /*body_outputs=*/nullptr,
          /*name=*/nullptr};
}
|
|
|
|
} // namespace
|
|
|
|
// Starts building a while loop in `g` over the `ninputs` loop variables
// `inputs`. Creates fresh cond and body graphs whose parent is `g`, and adds
// one Placeholder per loop variable to each graph (reported via `cond_inputs`
// and `body_inputs`). The caller must then fill in `cond_output`,
// `body_outputs` and `name`, and call TF_FinishWhile() or TF_AbortWhile().
// On failure `status` is set, everything allocated here is freed, and an
// empty TF_WhileParams is returned.
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
                           TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return EmptyWhileParams();
#else
  // Reject negative counts too (matching ValidateConstWhileParams): they would
  // otherwise reach `new TF_Output[ninputs]` below. A null `inputs` would be
  // dereferenced by CreateInput().
  if (ninputs <= 0 || inputs == nullptr) {
    status->status =
        InvalidArgument("TF_NewWhile() must be passed at least one input");
    return EmptyWhileParams();
  }

  TF_Graph* cond_graph = TF_NewGraph();
  TF_Graph* body_graph = TF_NewGraph();
  cond_graph->parent = g;
  cond_graph->parent_inputs = inputs;
  body_graph->parent = g;
  body_graph->parent_inputs = inputs;

  TF_Output* cond_inputs = new TF_Output[ninputs];
  TF_Output cond_output = {nullptr, -1};
  TF_Output* body_inputs = new TF_Output[ninputs];
  TF_Output* body_outputs = new TF_Output[ninputs];
  // Marks every body output as "unset" for ValidateInputWhileParams().
  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
  const char* name = nullptr;

  for (int i = 0; i < ninputs; ++i) {
    // TODO(skyewm): prefix names with underscore (requires some plumbing)
    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
                     &cond_inputs[i], status)) {
      break;
    }
    if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
                     &body_inputs[i], status)) {
      break;
    }
  }

  TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
                           body_graph, body_inputs, body_outputs, name};

  if (!status->status.ok()) {
    FreeWhileResources(&params);
    return EmptyWhileParams();
  }
  return params;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
namespace {
|
|
|
|
// TODO(skyewm): make nodes in while loop unfetchable like in Python version
|
|
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
|
|
TF_Output* outputs) {
|
|
if (!ValidateInputWhileParams(*params, status)) return;
|
|
|
|
TF_Graph* parent = params->cond_graph->parent;
|
|
TF_Output* parent_inputs = params->cond_graph->parent_inputs;
|
|
int num_loop_vars = params->ninputs;
|
|
|
|
mutex_lock l(parent->mu);
|
|
|
|
// 'cond_fn' copies the cond graph into the parent graph.
|
|
tensorflow::ops::CondGraphBuilderFn cond_fn =
|
|
[params, parent](const tensorflow::Scope& scope,
|
|
const std::vector<tensorflow::Output>& inputs,
|
|
tensorflow::Output* output) {
|
|
DCHECK_EQ(scope.graph(), &parent->graph);
|
|
std::vector<tensorflow::Output> cond_output;
|
|
TF_RETURN_IF_ERROR(CopyGraph(
|
|
¶ms->cond_graph->graph, &parent->graph, &parent->refiner,
|
|
params->cond_inputs, inputs, scope.impl()->name(),
|
|
scope.impl()->control_deps(), ¶ms->cond_output,
|
|
/* nreturn_nodes */ 1, &cond_output));
|
|
*output = cond_output[0];
|
|
return Status::OK();
|
|
};
|
|
|
|
// 'body_fn' copies the body graph into the parent graph.
|
|
tensorflow::ops::BodyGraphBuilderFn body_fn =
|
|
[params, parent, num_loop_vars](
|
|
const tensorflow::Scope& scope,
|
|
const std::vector<tensorflow::Output>& inputs,
|
|
std::vector<tensorflow::Output>* outputs) {
|
|
DCHECK_EQ(scope.graph(), &parent->graph);
|
|
TF_RETURN_IF_ERROR(
|
|
CopyGraph(¶ms->body_graph->graph, &parent->graph,
|
|
&parent->refiner, params->body_inputs, inputs,
|
|
scope.impl()->name(), scope.impl()->control_deps(),
|
|
params->body_outputs, num_loop_vars, outputs));
|
|
return Status::OK();
|
|
};
|
|
|
|
// Create the while loop using an internal scope.
|
|
tensorflow::Scope scope =
|
|
NewInternalScope(&parent->graph, &status->status, &parent->refiner)
|
|
.NewSubScope(params->name);
|
|
|
|
const int first_new_node_id = parent->graph.num_node_ids();
|
|
|
|
tensorflow::OutputList loop_outputs;
|
|
status->status = tensorflow::ops::BuildWhileLoop(
|
|
scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
|
|
body_fn, params->name, &loop_outputs);
|
|
|
|
// Update name_map with newly-created ops.
|
|
// TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
|
|
// a bad status. Once we fix this, we may want to return early instead of
|
|
// executing the following code.
|
|
for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
|
|
Node* new_node = parent->graph.FindNodeId(i);
|
|
if (new_node == nullptr) continue;
|
|
parent->name_map[new_node->name()] = new_node;
|
|
}
|
|
|
|
// Populate 'outputs'.
|
|
DCHECK_LE(loop_outputs.size(), num_loop_vars);
|
|
for (int i = 0; i < loop_outputs.size(); ++i) {
|
|
outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
|
|
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
|
|
TF_Output* outputs) {
|
|
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
|
status->status = tensorflow::errors::Unimplemented(
|
|
"Creating while loops is not supported on mobile. File a bug at "
|
|
"https://github.com/tensorflow/tensorflow/issues if this feature is "
|
|
"important to you");
|
|
#else
|
|
// If it appears the caller created or modified `params`, don't free resources
|
|
if (!ValidateConstWhileParams(*params, status)) return;
|
|
TF_FinishWhileHelper(params, status, outputs);
|
|
FreeWhileResources(params);
|
|
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
|
}
|
|
|
|
// Abandons a while loop started with TF_NewWhile() without building it,
// releasing the graphs and arrays that `params` holds.
void TF_AbortWhile(const TF_WhileParams* params) {
  FreeWhileResources(params);
}
|
|
|
|
// Same as TF_AddGradientsWithPrefix() with a null prefix, so the gradient ops
// are created under the default "gradients" scope.
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
  const char* const default_prefix = nullptr;
  TF_AddGradientsWithPrefix(g, default_prefix, y, ny, x, nx, dx, status, dy);
}
|
|
|
|
// Adds symbolic-gradient ops for the `ny` outputs `y` with respect to the `nx`
// inputs `x` to graph `g`, optionally seeded with the `ny` tensors in `dx`,
// and writes the resulting gradient tensors to `dy` (presumably one per entry
// of `x` — see AddSymbolicGradients()).
//
// New ops are placed under scope `prefix`, or "gradients" when `prefix` is
// null. An explicit prefix must not collide with any existing node name.
// Every new node is registered in `g->name_map`.
void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
                               int ny, TF_Output* x, int nx, TF_Output* dx,
                               TF_Status* status, TF_Output* dy) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  const std::vector<tensorflow::Output> ys = OutputsFromTFOutputs(y, ny);
  const std::vector<tensorflow::Output> xs = OutputsFromTFOutputs(x, nx);
  std::vector<tensorflow::Output> grads;

  {
    // The scope built below refers to the TF_Graph, so keep the graph locked
    // for as long as that scope is alive.
    mutex_lock graph_lock(g->mu);

    const int first_new_node_id = g->graph.num_node_ids();
    const bool user_prefix = (prefix != nullptr);

    // With an explicit prefix, every new node must be named "<prefix>/...".
    const string prefix_with_slash =
        user_prefix ? string(prefix) + "/" : string();
    const char* const child_scope_name = user_prefix ? prefix : "gradients";

    if (user_prefix) {
      // An explicit prefix that is already used in this graph is an error.
      for (const auto& entry : g->name_map) {
        const string& existing = entry.first;
        if (existing == prefix ||
            absl::StartsWith(existing, prefix_with_slash)) {
          status->status = InvalidArgument(
              "prefix [", prefix,
              "] conflicts with existing node in the graph named [", existing,
              "]");
          return;
        }
      }
    }

    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope(child_scope_name);

    status->status =
        (dx == nullptr)
            ? AddSymbolicGradients(scope, ys, xs, &grads)
            : AddSymbolicGradients(scope, ys, xs, OutputsFromTFOutputs(dx, ny),
                                   &grads);

    // Register every node created since `first_new_node_id` in g->name_map.
    for (int id = first_new_node_id; id < g->graph.num_node_ids(); ++id) {
      Node* node = g->graph.FindNodeId(id);
      if (node == nullptr) continue;

      // The graph construction API may rename nodes to avoid collisions only
      // when no prefix was supplied; an explicit prefix must remain intact.
      if (user_prefix && !absl::StartsWith(node->name(), prefix_with_slash)) {
        status->status = tensorflow::errors::Internal(
            "BUG: The gradients prefix have been unexpectedly altered when "
            "adding the nodes to the graph. This is a bug. Please file an "
            "issue at https://github.com/tensorflow/tensorflow/issues.");
        return;
      }

      // Nodes added through the C++ graph construction API bypass the name
      // uniqueness checks done by e.g. TF_FinishOperation, so verify here.
      const bool inserted = g->name_map.emplace(node->name(), node).second;
      if (!inserted) {
        status->status = tensorflow::errors::Internal(
            "BUG: The API allowed construction of a graph with duplicate node "
            "names (",
            node->name(),
            "). This is a bug. Please file an issue at "
            "https://github.com/tensorflow/tensorflow/issues.");
      }
    }
  }

  // Copy the computed gradients out to the caller's `dy` array.
  TFOutputsFromOutputs(grads, dy);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// TF_Session functions ----------------------------------------------
|
|
|
|
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
|
|
: session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
|
|
|
|
// Creates a session configured by `opt`. If `graph` is non-null, the new
// session is recorded in `graph->sessions` under the graph's lock. On failure,
// logs the error, sets `status` and returns nullptr.
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                          TF_Status* status) {
  Session* raw_session = nullptr;
  status->status = NewSession(opt->options, &raw_session);
  if (!status->status.ok()) {
    LOG(ERROR) << status->status;
    DCHECK_EQ(nullptr, raw_session);
    return nullptr;
  }

  TF_Session* wrapped = new TF_Session(raw_session, graph);
  if (graph != nullptr) {
    mutex_lock l(graph->mu);
    graph->sessions[wrapped] = "";
  }
  return wrapped;
}
|
|
|
|
// Loads the SavedModel in `export_dir` whose MetaGraphDef matches `tags`,
// imports its graph into the empty `graph`, and returns a session attached to
// that graph. If `meta_graph_def` is non-null it receives the serialized
// MetaGraphDef. Returns nullptr with `status` set on any failure.
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
// that the tensorflow/cc/saved_model:loader build target is mobile friendly.
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  mutex_lock l(graph->mu);
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  RunOptions run_opts;
  const bool have_run_options = (run_options != nullptr);
  if (have_run_options &&
      !run_opts.ParseFromArray(run_options->data, run_options->length)) {
    status->status = InvalidArgument("Unparseable RunOptions proto");
    return nullptr;
  }

  std::unordered_set<string> tag_set;
  for (int t = 0; t < tags_len; ++t) {
    tag_set.emplace(tags[t]);
  }

  tensorflow::SavedModelBundle bundle;
  status->status = tensorflow::LoadSavedModel(
      session_options->options, run_opts, export_dir, tag_set, &bundle);
  if (!status->status.ok()) return nullptr;

  // Re-create the graph in `graph` from the loaded MetaGraphDef. This is safe
  // as long as Session extends using GraphDefs: the Graph instance differs
  // from, but is equivalent to, the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefResults import_results;
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &import_results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (!status->status.ok()) return nullptr;

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) return nullptr;
  }

  // Hand ownership of the loaded session to the new TF_Session.
  TF_Session* loaded = new TF_Session(bundle.session.release(), graph);
  graph->sessions[loaded] = "";
  loaded->last_num_graph_nodes = graph->graph.num_node_ids();
  return loaded;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Closes the wrapped tensorflow::Session and reports the result in `status`.
// Does not free `s`; see TF_DeleteSession().
void TF_CloseSession(TF_Session* s, TF_Status* status) {
  tensorflow::Session* const underlying = s->session;
  status->status = underlying->Close();
}
|
|
|
|
// Frees `s` and the session it wraps; `status` is always set to OK and a null
// `s` is a no-op. Also unregisters `s` from its graph, and deletes that graph
// if its `delete_requested` flag is set (presumably by TF_DeleteGraph()) and
// no other sessions remain.
void TF_DeleteSession(TF_Session* s, TF_Status* status) {
  status->status = Status::OK();
  if (s == nullptr) return;

  TF_Graph* const graph = s->graph;
  if (graph != nullptr) {
    bool delete_graph = false;
    {
      mutex_lock l(graph->mu);
      graph->sessions.erase(s);
      delete_graph = graph->delete_requested && graph->sessions.empty();
    }
    // The graph owns the mutex, so it may only be destroyed after unlocking.
    if (delete_graph) delete graph;
  }
  delete s->session;
  delete s;
}
|
|
|
|
// Runs one step of `session`: feeds `input_values` to the `ninputs` tensors
// `inputs`, fetches the `noutputs` tensors `outputs` into `output_values`, and
// runs the `ntargets` ops `target_opers`. Any new graph nodes are first sent
// to the session when `extend_before_run` is set.
void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
                   const TF_Output* inputs, TF_Tensor* const* input_values,
                   int ninputs, const TF_Output* outputs,
                   TF_Tensor** output_values, int noutputs,
                   const TF_Operation* const* target_opers, int ntargets,
                   TF_Buffer* run_metadata, TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Feeds: (tensor name, value) pairs.
  std::vector<std::pair<string, Tensor>> feeds(ninputs);
  if (!TF_Run_Inputs(input_values, &feeds, status)) return;
  for (int k = 0; k < ninputs; ++k) {
    feeds[k].first = OutputName(inputs[k]);
  }

  // Fetches: tensor names.
  std::vector<string> fetch_names(noutputs);
  for (int k = 0; k < noutputs; ++k) {
    fetch_names[k] = OutputName(outputs[k]);
  }

  // Targets: op names.
  std::vector<string> target_names(ntargets);
  for (int k = 0; k < ntargets; ++k) {
    target_names[k] = target_opers[k]->node.name();
  }

  TF_Run_Helper(session->session, /*handle=*/nullptr, run_options, feeds,
                fetch_names, output_values, target_names, run_metadata,
                status);
}
|
|
|
|
// Prepares a partial run that may feed `inputs`, fetch `outputs` and run
// `target_opers`. On success `*handle` receives a new[]-allocated,
// NUL-terminated copy of the session's partial-run handle, to be released
// with TF_DeletePRunHandle(). On failure `*handle` stays nullptr.
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;

  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  std::vector<string> feed_names(ninputs);
  for (int k = 0; k < ninputs; ++k) feed_names[k] = OutputName(inputs[k]);

  std::vector<string> fetch_names(noutputs);
  for (int k = 0; k < noutputs; ++k) fetch_names[k] = OutputName(outputs[k]);

  std::vector<string> target_names(ntargets);
  for (int k = 0; k < ntargets; ++k) {
    target_names[k] = target_opers[k]->node.name();
  }

  string partial_run_handle;
  status->status = session->session->PRunSetup(
      feed_names, fetch_names, target_names, &partial_run_handle);
  if (!status->status.ok()) return;

  // Copy including the trailing NUL so the caller receives a C string.
  const size_t nbytes = partial_run_handle.size() + 1;
  char* handle_copy = new char[nbytes];
  memcpy(handle_copy, partial_run_handle.c_str(), nbytes);
  *handle = handle_copy;
}
|
|
|
|
// Releases a handle produced by TF_SessionPRunSetup(), which allocates it with
// new[].
// TODO(suharshs): Free up any resources held by the partial run state.
void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
}
|
|
|
|
// Executes one step of the partial run identified by `handle` (obtained from
// TF_SessionPRunSetup()): feeds `input_values` to `inputs`, fetches `outputs`
// into `output_values`, and runs `target_opers`.
void TF_SessionPRun(TF_Session* session, const char* handle,
                    const TF_Output* inputs, TF_Tensor* const* input_values,
                    int ninputs, const TF_Output* outputs,
                    TF_Tensor** output_values, int noutputs,
                    const TF_Operation* const* target_opers, int ntargets,
                    TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Feeds: (tensor name, value) pairs.
  std::vector<std::pair<string, Tensor>> feeds(ninputs);
  if (!TF_Run_Inputs(input_values, &feeds, status)) return;
  for (int k = 0; k < ninputs; ++k) {
    feeds[k].first = OutputName(inputs[k]);
  }

  // Fetches: tensor names.
  std::vector<string> fetch_names(noutputs);
  for (int k = 0; k < noutputs; ++k) {
    fetch_names[k] = OutputName(outputs[k]);
  }

  // Targets: op names.
  std::vector<string> target_names(ntargets);
  for (int k = 0; k < ntargets; ++k) {
    target_names[k] = target_opers[k]->node.name();
  }

  // Partial runs pass neither RunOptions nor RunMetadata.
  TF_Run_Helper(session->session, handle, /*run_options=*/nullptr, feeds,
                fetch_names, output_values, target_names,
                /*run_metadata=*/nullptr, status);
}
|
|
|
|
// Tries to evaluate `output` as a constant using `graph`'s shape refiner and
// op registry. Returns true and stores a new tensor in `*result` on success.
// Otherwise returns false, and `*result` is nullptr unless the final tensor
// conversion failed; `status` carries any error. Holds the graph's lock
// throughout.
unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
                                     TF_Tensor** result, TF_Status* status) {
  *result = nullptr;
  mutex_lock l(graph->mu);

  OutputTensor target(&output.oper->node, output.index);
  bool evaluated = false;
  Tensor value;
  status->status = EvaluateConstantTensor(
      target, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &value);
  if (!evaluated) return false;

  DCHECK(status->status.ok());
  *result = TF_TensorFromTensor(value, &status->status);
  return status->status.ok();
}
|
|
|
|
// Builds a TF_ApiDefMap from the serialized OpList in `op_list_buffer`.
// Returns nullptr with InvalidArgument in `status` if the buffer does not
// parse.
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  const bool parsed =
      op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
  if (parsed) {
    status->status = Status::OK();
    return new TF_ApiDefMap(op_list);
  }
  status->status = InvalidArgument("Unparseable OpList");
  return nullptr;
}
|
|
|
|
// Destroys a map created by TF_NewApiDefMap().
void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) {
  delete apimap;
}
|
|
|
|
// Loads the `text_len` bytes of ApiDef text in `text` into `api_def_map`.
// Fails with FailedPrecondition once TF_ApiDefMapGet() has been called on the
// map.
void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
  } else {
    const string api_def_text(text, text_len);
    status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
  }
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Returns the serialized ApiDef for the op named by the `name_len` bytes of
// `name`, or nullptr if no such op is known or serialization fails. The first
// call finalizes the map's docs (after which TF_ApiDefMapPut() is rejected).
TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }

  const string op_name(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(op_name);
  // NOTE(review): an unknown op yields nullptr without touching `status`.
  if (api_def == nullptr) return nullptr;

  TF_Buffer* serialized = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, serialized);
  if (status->status.ok()) return serialized;
  TF_DeleteBuffer(serialized);
  return nullptr;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Returns a buffer holding the serialized KernelList of every registered
// kernel, or nullptr with `status` set if serialization fails.
TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  const tensorflow::KernelList kernels = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* serialized = TF_NewBuffer();
  status->status = MessageToBuffer(kernels, serialized);
  if (status->status.ok()) return serialized;
  TF_DeleteBuffer(serialized);
  return nullptr;
}
|
|
|
|
// Returns a buffer holding the serialized KernelList of the kernels registered
// for op `name`, or nullptr with `status` set if serialization fails.
TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  const tensorflow::KernelList kernels =
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* serialized = TF_NewBuffer();
  status->status = MessageToBuffer(kernels, serialized);
  if (status->status.ok()) return serialized;
  TF_DeleteBuffer(serialized);
  return nullptr;
}
|
|
|
|
// TF_Server functions ----------------------------------------------
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
|
|
: target(server->target()), server(std::move(server)) {}
|
|
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
|
|
|
// Creates an in-process TensorFlow server from the `proto_len`-byte serialized
// ServerDef in `proto`. Returns nullptr with `status` set if the bytes cannot
// be parsed or the server cannot be created.
TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  // ParseFromArray() takes an int length. Without this check the cast below
  // would truncate larger sizes, so only a prefix of the buffer (or a
  // negative length) would be handed to the parser.
  if (proto_len > static_cast<size_t>(std::numeric_limits<int>::max())) {
    status->status = InvalidArgument(
        "ServerDef protocol buffer is too large to parse (", proto_len,
        " bytes)");
    return nullptr;
  }

  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Starts the wrapped server and reports the result in `status`.
void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  tensorflow::ServerInterface* const impl = server->server.get();
  status->status = impl->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Stops the wrapped server and reports the result in `status`.
void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  tensorflow::ServerInterface* const impl = server->server.get();
  status->status = impl->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Calls Join() on the wrapped server and reports the result in `status`.
void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  tensorflow::ServerInterface* const impl = server->server.get();
  status->status = impl->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Returns the server's target string, owned by `server`; always nullptr on
// mobile / slim builds.
const char* TF_ServerTarget(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  return server->target.c_str();
#else
  return nullptr;
#endif
}
|
|
|
|
// Destroys a server created by TF_NewServer(); a no-op on mobile / slim
// builds.
void TF_DeleteServer(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  (void)server;
#else
  delete server;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
// Registers `listener` with tensorflow::logging::RegisterListener(); ignored
// on mobile / slim builds.
void TF_RegisterLogListener(void (*listener)(const char*)) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  (void)listener;
#else
  tensorflow::logging::RegisterListener(listener);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
|
|
|
|
} // end extern "C"
|