Revert "Merge branch 'r0.10' of https://github.com/tensorflow/tensorflow"
This reverts commita3539967e2
, reversing changes made toee221cb625
.
This commit is contained in:
parent
51e9756b62
commit
d027643a74
@ -441,7 +441,7 @@ static void TF_Run_Helper(
|
||||
const std::vector<tensorflow::string>& output_tensor_names,
|
||||
TF_Tensor** c_outputs,
|
||||
// Target nodes
|
||||
const std::vector<tensorflow::string>& target_oper_names,
|
||||
const std::vector<tensorflow::string>& target_node_names,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
const int noutputs = output_tensor_names.size();
|
||||
std::vector<Tensor> outputs(noutputs);
|
||||
@ -464,7 +464,7 @@ static void TF_Run_Helper(
|
||||
|
||||
RunMetadata run_metadata_proto;
|
||||
result = session->Run(run_options_proto, input_pairs, output_tensor_names,
|
||||
target_oper_names, &outputs, &run_metadata_proto);
|
||||
target_node_names, &outputs, &run_metadata_proto);
|
||||
|
||||
// Serialize back to upstream client, who now owns the new buffer
|
||||
if (run_metadata != nullptr) {
|
||||
@ -512,9 +512,10 @@ void TF_Run(TF_Session* s, const TF_Buffer* run_options,
|
||||
// Input tensors
|
||||
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
|
||||
// Output tensors
|
||||
const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
|
||||
const char** c_output_tensor_names, TF_Tensor** c_outputs,
|
||||
int noutputs,
|
||||
// Target nodes
|
||||
const char** c_target_oper_names, int ntargets,
|
||||
const char** c_target_node_names, int ntargets,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
TF_Run_Setup(noutputs, c_outputs, status);
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
@ -522,44 +523,45 @@ void TF_Run(TF_Session* s, const TF_Buffer* run_options,
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_pairs[i].first = c_input_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<tensorflow::string> output_tensor_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = c_output_names[i];
|
||||
output_tensor_names[i] = c_output_tensor_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<tensorflow::string> target_node_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
target_node_names[i] = c_target_node_names[i];
|
||||
}
|
||||
TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
|
||||
c_outputs, target_oper_names, run_metadata, status);
|
||||
TF_Run_Helper(s->session, nullptr, run_options, input_pairs,
|
||||
output_tensor_names, c_outputs, target_node_names, run_metadata,
|
||||
status);
|
||||
}
|
||||
|
||||
void TF_PRunSetup(TF_Session* s,
|
||||
// Input names
|
||||
const char** c_input_names, int ninputs,
|
||||
// Output names
|
||||
const char** c_output_names, int noutputs,
|
||||
const char** c_output_tensor_names, int noutputs,
|
||||
// Target nodes
|
||||
const char** c_target_oper_names, int ntargets,
|
||||
const char** c_target_node_names, int ntargets,
|
||||
const char** handle, TF_Status* status) {
|
||||
status->status = Status::OK();
|
||||
|
||||
std::vector<tensorflow::string> input_names(ninputs);
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<tensorflow::string> output_tensor_names(noutputs);
|
||||
std::vector<tensorflow::string> target_node_names(ntargets);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_names[i] = c_input_names[i];
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = c_output_names[i];
|
||||
output_tensor_names[i] = c_output_tensor_names[i];
|
||||
}
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
target_node_names[i] = c_target_node_names[i];
|
||||
}
|
||||
tensorflow::string new_handle;
|
||||
Status result;
|
||||
result = s->session->PRunSetup(input_names, output_names, target_oper_names,
|
||||
&new_handle);
|
||||
result = s->session->PRunSetup(input_names, output_tensor_names,
|
||||
target_node_names, &new_handle);
|
||||
if (result.ok()) {
|
||||
char* buf = new char[new_handle.size() + 1];
|
||||
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
|
||||
@ -573,9 +575,10 @@ void TF_PRun(TF_Session* s, const char* handle,
|
||||
// Input tensors
|
||||
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
|
||||
// Output tensors
|
||||
const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
|
||||
const char** c_output_tensor_names, TF_Tensor** c_outputs,
|
||||
int noutputs,
|
||||
// Target nodes
|
||||
const char** c_target_oper_names, int ntargets,
|
||||
const char** c_target_node_names, int ntargets,
|
||||
TF_Status* status) {
|
||||
TF_Run_Setup(noutputs, c_outputs, status);
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
@ -584,16 +587,16 @@ void TF_PRun(TF_Session* s, const char* handle,
|
||||
input_pairs[i].first = c_input_names[i];
|
||||
}
|
||||
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<tensorflow::string> output_tensor_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = c_output_names[i];
|
||||
output_tensor_names[i] = c_output_tensor_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<tensorflow::string> target_node_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
target_node_names[i] = c_target_node_names[i];
|
||||
}
|
||||
TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
|
||||
c_outputs, target_oper_names, nullptr, status);
|
||||
TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_tensor_names,
|
||||
c_outputs, target_node_names, nullptr, status);
|
||||
}
|
||||
|
||||
struct TF_Library {
|
||||
@ -640,16 +643,15 @@ struct TF_Graph {
|
||||
bool delete_requested; // set true by TF_DeleteGraph
|
||||
};
|
||||
|
||||
struct TF_OperationDescription {
|
||||
TF_OperationDescription(TF_Graph* g, const char* op_type,
|
||||
const char* node_name)
|
||||
struct TF_NodeDescription {
|
||||
TF_NodeDescription(TF_Graph* g, const char* op_type, const char* node_name)
|
||||
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
|
||||
|
||||
NodeBuilder node_builder;
|
||||
TF_Graph* graph;
|
||||
};
|
||||
|
||||
struct TF_Operation {
|
||||
struct TF_Node {
|
||||
Node node;
|
||||
};
|
||||
|
||||
@ -668,56 +670,55 @@ struct TF_SessionWithGraph {
|
||||
|
||||
namespace {
|
||||
|
||||
TF_Operation* ToOperation(Node* node) {
|
||||
return static_cast<TF_Operation*>(static_cast<void*>(node));
|
||||
TF_Node* ToNode(Node* node) {
|
||||
return static_cast<TF_Node*>(static_cast<void*>(node));
|
||||
}
|
||||
|
||||
tensorflow::string PortName(const TF_Port& port) {
|
||||
return tensorflow::strings::StrCat(port.oper->node.name(), ":", port.index);
|
||||
return tensorflow::strings::StrCat(port.node->node.name(), ":", port.index);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TF_OperationDescription functions
|
||||
// -----------------------------------------------
|
||||
// TF_NodeDescription functions -----------------------------------------------
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
|
||||
const char* oper_name) {
|
||||
TF_NodeDescription* TF_NewNode(TF_Graph* graph, const char* op_type,
|
||||
const char* node_name) {
|
||||
mutex_lock l(graph->mu);
|
||||
return new TF_OperationDescription(graph, op_type, oper_name);
|
||||
return new TF_NodeDescription(graph, op_type, node_name);
|
||||
}
|
||||
|
||||
void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
|
||||
void TF_SetDevice(TF_NodeDescription* desc, const char* device) {
|
||||
desc->node_builder.Device(device);
|
||||
}
|
||||
|
||||
void TF_AddInput(TF_OperationDescription* desc, TF_Port input) {
|
||||
desc->node_builder.Input(&input.oper->node, input.index);
|
||||
void TF_AddInput(TF_NodeDescription* desc, TF_Port input) {
|
||||
desc->node_builder.Input(&input.node->node, input.index);
|
||||
}
|
||||
|
||||
void TF_AddInputList(TF_OperationDescription* desc, const TF_Port* inputs,
|
||||
void TF_AddInputList(TF_NodeDescription* desc, const TF_Port* inputs,
|
||||
int num_inputs) {
|
||||
std::vector<NodeBuilder::NodeOut> input_list;
|
||||
input_list.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
|
||||
input_list.emplace_back(&inputs[i].node->node, inputs[i].index);
|
||||
}
|
||||
desc->node_builder.Input(input_list);
|
||||
}
|
||||
|
||||
void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
|
||||
void TF_AddControlInput(TF_NodeDescription* desc, TF_Node* input) {
|
||||
desc->node_builder.ControlInput(&input->node);
|
||||
}
|
||||
|
||||
void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrString(TF_NodeDescription* desc, const char* attr_name,
|
||||
const void* value, int length) {
|
||||
tensorflow::StringPiece s(static_cast<const char*>(value), length);
|
||||
desc->node_builder.Attr(attr_name, s);
|
||||
}
|
||||
|
||||
void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrStringList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const void* const* values, const int* lengths,
|
||||
int num_values) {
|
||||
std::vector<tensorflow::StringPiece> v;
|
||||
@ -728,14 +729,14 @@ void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->node_builder.Attr(attr_name, v);
|
||||
}
|
||||
|
||||
void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrInt(TF_NodeDescription* desc, const char* attr_name,
|
||||
int64_t value) {
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
|
||||
}
|
||||
|
||||
void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrIntList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
@ -745,23 +746,23 @@ void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
|
||||
reinterpret_cast<const tensorflow::int64*>(values), num_values));
|
||||
}
|
||||
|
||||
void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrFloat(TF_NodeDescription* desc, const char* attr_name,
|
||||
float value) {
|
||||
desc->node_builder.Attr(attr_name, value);
|
||||
}
|
||||
|
||||
void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrFloatList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
desc->node_builder.Attr(attr_name,
|
||||
ArraySlice<const float>(values, num_values));
|
||||
}
|
||||
|
||||
void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrBool(TF_NodeDescription* desc, const char* attr_name,
|
||||
unsigned char value) {
|
||||
desc->node_builder.Attr(attr_name, static_cast<bool>(value));
|
||||
}
|
||||
|
||||
void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrBoolList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
bool* b = new bool[num_values];
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
@ -770,19 +771,19 @@ void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->node_builder.Attr(attr_name, ArraySlice<const bool>(b, num_values));
|
||||
}
|
||||
|
||||
void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrType(TF_NodeDescription* desc, const char* attr_name,
|
||||
TF_DataType value) {
|
||||
desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
|
||||
}
|
||||
|
||||
void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrTypeList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
desc->node_builder.Attr(
|
||||
attr_name, ArraySlice<const DataType>(
|
||||
reinterpret_cast<const DataType*>(values), num_values));
|
||||
}
|
||||
|
||||
void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrShape(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* dims, int num_dims) {
|
||||
PartialTensorShape shape;
|
||||
if (num_dims >= 0) {
|
||||
@ -794,7 +795,7 @@ void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->node_builder.Attr(attr_name, shape);
|
||||
}
|
||||
|
||||
void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrShapeList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* const* dims, const int* num_dims,
|
||||
int num_shapes) {
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
@ -812,9 +813,8 @@ void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->node_builder.Attr(attr_name, shapes);
|
||||
}
|
||||
|
||||
void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
|
||||
const char* attr_name, void* proto,
|
||||
int proto_len, TF_Status* status) {
|
||||
void TF_SetAttrTensorShapeProto(TF_NodeDescription* desc, const char* attr_name,
|
||||
void* proto, int proto_len, TF_Status* status) {
|
||||
TensorShapeProto shape;
|
||||
if (shape.ParseFromArray(proto, proto_len)) {
|
||||
desc->node_builder.Attr(attr_name, shape);
|
||||
@ -825,7 +825,7 @@ void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
|
||||
}
|
||||
}
|
||||
|
||||
void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
|
||||
void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc,
|
||||
const char* attr_name,
|
||||
const void* const* protos,
|
||||
const int* proto_lens, int num_shapes,
|
||||
@ -843,7 +843,7 @@ void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
|
||||
status->status = Status::OK();
|
||||
}
|
||||
|
||||
void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrTensor(TF_NodeDescription* desc, const char* attr_name,
|
||||
TF_Tensor* value, TF_Status* status) {
|
||||
status->status = Status::OK();
|
||||
Tensor t;
|
||||
@ -862,7 +862,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
|
||||
if (ok) desc->node_builder.Attr(attr_name, t);
|
||||
}
|
||||
|
||||
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
|
||||
void TF_SetAttrTensorList(TF_NodeDescription* desc, const char* attr_name,
|
||||
TF_Tensor* const* values, int num_values,
|
||||
TF_Status* status) {
|
||||
status->status = Status::OK();
|
||||
@ -890,9 +890,9 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
|
||||
if (ok) desc->node_builder.Attr(attr_name, t);
|
||||
}
|
||||
|
||||
void TF_SetAttrToAttrValueProto(TF_OperationDescription* desc,
|
||||
const char* attr_name, const void* proto,
|
||||
size_t proto_len, TF_Status* status) {
|
||||
void TF_SetAttrToAttrValueProto(TF_NodeDescription* desc, const char* attr_name,
|
||||
const void* proto, size_t proto_len,
|
||||
TF_Status* status) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (attr_value.ParseFromArray(proto, proto_len)) {
|
||||
desc->node_builder.Attr(attr_name, attr_value);
|
||||
@ -903,8 +903,7 @@ void TF_SetAttrToAttrValueProto(TF_OperationDescription* desc,
|
||||
}
|
||||
}
|
||||
|
||||
TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
|
||||
TF_Status* status) {
|
||||
TF_Node* TF_FinishNode(TF_NodeDescription* desc, TF_Status* status) {
|
||||
Node* ret = nullptr;
|
||||
mutex_lock l(desc->graph->mu);
|
||||
|
||||
@ -920,37 +919,32 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
|
||||
|
||||
delete desc;
|
||||
|
||||
return ToOperation(ret);
|
||||
return ToNode(ret);
|
||||
}
|
||||
|
||||
// TF_Operation functions
|
||||
// ----------------------------------------------------------
|
||||
// TF_Node functions ----------------------------------------------------------
|
||||
|
||||
const char* TF_OperationName(TF_Operation* oper) {
|
||||
return oper->node.name().c_str();
|
||||
const char* TF_NodeName(TF_Node* node) { return node->node.name().c_str(); }
|
||||
|
||||
const char* TF_NodeOpType(TF_Node* node) {
|
||||
return node->node.type_string().c_str();
|
||||
}
|
||||
|
||||
const char* TF_OperationOpType(TF_Operation* oper) {
|
||||
return oper->node.type_string().c_str();
|
||||
const char* TF_NodeDevice(TF_Node* node) {
|
||||
return node->node.def().device().c_str();
|
||||
}
|
||||
|
||||
const char* TF_OperationDevice(TF_Operation* oper) {
|
||||
return oper->node.def().device().c_str();
|
||||
}
|
||||
int TF_NodeNumOutputs(TF_Node* node) { return node->node.num_outputs(); }
|
||||
|
||||
int TF_OperationNumOutputs(TF_Operation* oper) {
|
||||
return oper->node.num_outputs();
|
||||
}
|
||||
|
||||
TF_DataType TF_OperationOutputType(TF_Port oper_out) {
|
||||
TF_DataType TF_NodeOutputType(TF_Port node_out) {
|
||||
return static_cast<TF_DataType>(
|
||||
oper_out.oper->node.output_type(oper_out.index));
|
||||
node_out.node->node.output_type(node_out.index));
|
||||
}
|
||||
|
||||
int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
|
||||
int TF_NodeOutputListLength(TF_Node* node, const char* arg_name,
|
||||
TF_Status* status) {
|
||||
NameRangeMap name_ranges;
|
||||
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
|
||||
status->status = NameRangesForNode(node->node.def(), node->node.op_def(),
|
||||
nullptr, &name_ranges);
|
||||
if (!status->status.ok()) return -1;
|
||||
auto iter = name_ranges.find(arg_name);
|
||||
@ -962,18 +956,16 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
|
||||
return iter->second.second - iter->second.first;
|
||||
}
|
||||
|
||||
int TF_OperationNumInputs(TF_Operation* oper) {
|
||||
return oper->node.num_inputs();
|
||||
int TF_NodeNumInputs(TF_Node* node) { return node->node.num_inputs(); }
|
||||
|
||||
TF_DataType TF_NodeInputType(TF_Port node_in) {
|
||||
return static_cast<TF_DataType>(node_in.node->node.input_type(node_in.index));
|
||||
}
|
||||
|
||||
TF_DataType TF_OperationInputType(TF_Port oper_in) {
|
||||
return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
|
||||
}
|
||||
|
||||
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
|
||||
int TF_NodeInputListLength(TF_Node* node, const char* arg_name,
|
||||
TF_Status* status) {
|
||||
NameRangeMap name_ranges;
|
||||
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
|
||||
status->status = NameRangesForNode(node->node.def(), node->node.op_def(),
|
||||
&name_ranges, nullptr);
|
||||
if (!status->status.ok()) return -1;
|
||||
auto iter = name_ranges.find(arg_name);
|
||||
@ -985,32 +977,32 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
|
||||
return iter->second.second - iter->second.first;
|
||||
}
|
||||
|
||||
TF_Port TF_OperationInput(TF_Port oper_in) {
|
||||
for (const auto* edge : oper_in.oper->node.in_edges()) {
|
||||
if (edge->dst_input() == oper_in.index) {
|
||||
return {ToOperation(edge->src()), edge->src_output()};
|
||||
TF_Port TF_NodeInput(TF_Port node_in) {
|
||||
for (const auto* edge : node_in.node->node.in_edges()) {
|
||||
if (edge->dst_input() == node_in.index) {
|
||||
return {ToNode(edge->src()), edge->src_output()};
|
||||
}
|
||||
}
|
||||
return {nullptr, -1};
|
||||
}
|
||||
|
||||
int TF_OperationOutputNumConsumers(TF_Port oper_out) {
|
||||
int TF_NodeOutputNumConsumers(TF_Port node_out) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper_out.oper->node.out_edges()) {
|
||||
if (edge->src_output() == oper_out.index) {
|
||||
for (const auto* edge : node_out.node->node.out_edges()) {
|
||||
if (edge->src_output() == node_out.index) {
|
||||
++count;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers,
|
||||
int TF_NodeOutputConsumers(TF_Port node_out, TF_Port* consumers,
|
||||
int max_consumers) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper_out.oper->node.out_edges()) {
|
||||
if (edge->src_output() == oper_out.index) {
|
||||
for (const auto* edge : node_out.node->node.out_edges()) {
|
||||
if (edge->src_output() == node_out.index) {
|
||||
if (count < max_consumers) {
|
||||
consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
|
||||
consumers[count] = {ToNode(edge->dst()), edge->dst_input()};
|
||||
}
|
||||
++count;
|
||||
}
|
||||
@ -1018,9 +1010,9 @@ int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers,
|
||||
return count;
|
||||
}
|
||||
|
||||
int TF_OperationNumControlInputs(TF_Operation* oper) {
|
||||
int TF_NodeNumControlInputs(TF_Node* node) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper->node.in_edges()) {
|
||||
for (const auto* edge : node->node.in_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
++count;
|
||||
}
|
||||
@ -1028,14 +1020,13 @@ int TF_OperationNumControlInputs(TF_Operation* oper) {
|
||||
return count;
|
||||
}
|
||||
|
||||
int TF_OperationGetControlInputs(TF_Operation* oper,
|
||||
TF_Operation** control_inputs,
|
||||
int TF_NodeGetControlInputs(TF_Node* node, TF_Node** control_inputs,
|
||||
int max_control_inputs) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper->node.in_edges()) {
|
||||
for (const auto* edge : node->node.in_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
if (count < max_control_inputs) {
|
||||
control_inputs[count] = ToOperation(edge->src());
|
||||
control_inputs[count] = ToNode(edge->src());
|
||||
}
|
||||
++count;
|
||||
}
|
||||
@ -1043,9 +1034,9 @@ int TF_OperationGetControlInputs(TF_Operation* oper,
|
||||
return count;
|
||||
}
|
||||
|
||||
int TF_OperationNumControlOutputs(TF_Operation* oper) {
|
||||
int TF_NodeNumControlOutputs(TF_Node* node) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper->node.out_edges()) {
|
||||
for (const auto* edge : node->node.out_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
++count;
|
||||
}
|
||||
@ -1053,14 +1044,13 @@ int TF_OperationNumControlOutputs(TF_Operation* oper) {
|
||||
return count;
|
||||
}
|
||||
|
||||
int TF_OperationGetControlOutputs(TF_Operation* oper,
|
||||
TF_Operation** control_outputs,
|
||||
int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs,
|
||||
int max_control_outputs) {
|
||||
int count = 0;
|
||||
for (const auto* edge : oper->node.out_edges()) {
|
||||
for (const auto* edge : node->node.out_edges()) {
|
||||
if (edge->IsControlEdge()) {
|
||||
if (count < max_control_outputs) {
|
||||
control_outputs[count] = ToOperation(edge->dst());
|
||||
control_outputs[count] = ToNode(edge->dst());
|
||||
}
|
||||
++count;
|
||||
}
|
||||
@ -1068,20 +1058,19 @@ int TF_OperationGetControlOutputs(TF_Operation* oper,
|
||||
return count;
|
||||
}
|
||||
|
||||
void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
|
||||
TF_Buffer* output_attr_value,
|
||||
TF_Status* status) {
|
||||
void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name,
|
||||
TF_Buffer* output_attr_value, TF_Status* status) {
|
||||
if (output_attr_value->data != nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Passing non-empty output_attr_value is invalid.");
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& attr_map = oper->node.def().attr();
|
||||
const auto& attr_map = node->node.def().attr();
|
||||
auto iter = attr_map.find(attr_name);
|
||||
if (iter == attr_map.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Operation has no attr named '", attr_name, "'.");
|
||||
"Node has no attr named '", attr_name, "'.");
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1097,7 +1086,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
|
||||
status->status = Status::OK();
|
||||
}
|
||||
|
||||
void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
|
||||
void TF_NodeToNodeDef(TF_Node* node, TF_Buffer* output_node_def,
|
||||
TF_Status* status) {
|
||||
if (output_node_def->data != nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -1105,7 +1094,7 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
|
||||
return;
|
||||
}
|
||||
|
||||
const NodeDef& def = oper->node.def();
|
||||
const NodeDef& def = node->node.def();
|
||||
const auto proto_size = def.ByteSize();
|
||||
void* str_buf = malloc(proto_size);
|
||||
def.SerializeToArray(str_buf, proto_size);
|
||||
@ -1129,17 +1118,17 @@ void TF_DeleteGraph(TF_Graph* g) {
|
||||
if (del) delete g;
|
||||
}
|
||||
|
||||
TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
|
||||
TF_Node* TF_GraphNodeByName(TF_Graph* graph, const char* node_name) {
|
||||
mutex_lock l(graph->mu);
|
||||
auto iter = graph->name_map.find(oper_name);
|
||||
auto iter = graph->name_map.find(node_name);
|
||||
if (iter == graph->name_map.end()) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return ToOperation(iter->second);
|
||||
return ToNode(iter->second);
|
||||
}
|
||||
}
|
||||
|
||||
TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
|
||||
TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos) {
|
||||
if (*pos == 0) {
|
||||
// Advance past the first sentinal nodes in every graph (the source & sink).
|
||||
*pos += 2;
|
||||
@ -1154,7 +1143,7 @@ TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
|
||||
// FindNodeId() returns nullptr for nodes that have been deleted.
|
||||
// We aren't currently allowing nodes to be deleted, but it is safer
|
||||
// to still check.
|
||||
if (node != nullptr) return ToOperation(node);
|
||||
if (node != nullptr) return reinterpret_cast<TF_Node*>(node);
|
||||
*pos += 1;
|
||||
}
|
||||
|
||||
@ -1268,7 +1257,7 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
|
||||
const TF_Port* inputs, TF_Tensor* const* input_values,
|
||||
int ninputs, const TF_Port* outputs,
|
||||
TF_Tensor** output_values, int noutputs,
|
||||
const TF_Operation* const* target_opers, int ntargets,
|
||||
const TF_Node* const* target_nodes, int ntargets,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
||||
// directly, instead of requiring us to serialize to a GraphDef and
|
||||
@ -1295,10 +1284,10 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
|
||||
output_names[i] = PortName(outputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Operation* to string names.
|
||||
// Convert from TF_Node* to string names.
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
target_names[i] = target_nodes[i]->node.name();
|
||||
}
|
||||
|
||||
// Actually run.
|
||||
@ -1309,7 +1298,7 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
|
||||
|
||||
void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs,
|
||||
int ninputs, const TF_Port* outputs, int noutputs,
|
||||
const TF_Operation* const* target_opers, int ntargets,
|
||||
const TF_Node* const* target_nodes, int ntargets,
|
||||
const char** handle, TF_Status* status) {
|
||||
if (!ExtendSessionGraphHelper(session, status)) {
|
||||
return;
|
||||
@ -1327,7 +1316,7 @@ void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs,
|
||||
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
target_names[i] = target_nodes[i]->node.name();
|
||||
}
|
||||
|
||||
tensorflow::string new_handle;
|
||||
@ -1344,7 +1333,7 @@ void TF_SessionPRun(TF_SessionWithGraph* session, const char* handle,
|
||||
const TF_Port* inputs, TF_Tensor* const* input_values,
|
||||
int ninputs, const TF_Port* outputs,
|
||||
TF_Tensor** output_values, int noutputs,
|
||||
const TF_Operation* const* target_opers, int ntargets,
|
||||
const TF_Node* const* target_nodes, int ntargets,
|
||||
TF_Status* status) {
|
||||
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
||||
// directly, instead of requiring us to serialize to a GraphDef and
|
||||
@ -1371,10 +1360,10 @@ void TF_SessionPRun(TF_SessionWithGraph* session, const char* handle,
|
||||
output_names[i] = PortName(outputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Operation* to string names.
|
||||
// Convert from TF_Node* to string names.
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
target_names[i] = target_nodes[i]->node.name();
|
||||
}
|
||||
|
||||
TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
|
||||
|
@ -247,31 +247,29 @@ extern TF_Graph* TF_NewGraph();
|
||||
// TFSessionWithGraph's are referencing it.
|
||||
extern void TF_DeleteGraph(TF_Graph*);
|
||||
|
||||
// Operation being built. The underlying graph must outlive this.
|
||||
typedef struct TF_OperationDescription TF_OperationDescription;
|
||||
// Node being built. The underlying graph must outlive this.
|
||||
typedef struct TF_NodeDescription TF_NodeDescription;
|
||||
|
||||
// Operation that has been added to the graph. Valid until the graph is
|
||||
// deleted -- in particular adding a new operation to the graph does not
|
||||
// invalidate old TF_Operation* pointers.
|
||||
typedef struct TF_Operation TF_Operation;
|
||||
// Node that has been added to the graph. Valid until the graph is
|
||||
// deleted -- in particular adding a new node to the graph does not
|
||||
// invalidate old TF_Node* pointers.
|
||||
typedef struct TF_Node TF_Node;
|
||||
|
||||
// Represents a specific input or output of an operation, e.g. to
|
||||
// specify the specific output to pass as an input to a new op.
|
||||
// Represents a specific input or output of a node, e.g. to specify the
|
||||
// specific output to pass as an input to an op.
|
||||
typedef struct TF_Port {
|
||||
TF_Operation* oper;
|
||||
int index; // Specifies the index of the input or output within oper.
|
||||
TF_Node* node;
|
||||
int index; // Specifies the index of the input or output within node.
|
||||
} TF_Port;
|
||||
|
||||
// Operation will only be added to *graph when TF_FinishOperation() is
|
||||
// called (assuming TF_FinishOperation() does not return an error).
|
||||
// *graph must not be deleted until after TF_FinishOperation() is
|
||||
// called.
|
||||
extern TF_OperationDescription* TF_NewOperation(TF_Graph* graph,
|
||||
const char* op_type,
|
||||
const char* oper_name);
|
||||
// Node will only be added to *graph when TF_FinishNode() is called
|
||||
// (assuming TF_FinishNode() does not return an error). *graph must
|
||||
// not be deleted until after TF_FinishNode() is called.
|
||||
extern TF_NodeDescription* TF_NewNode(TF_Graph* graph, const char* op_type,
|
||||
const char* node_name);
|
||||
|
||||
// Specify the device for `desc`. Defaults to empty, meaning unconstrained.
|
||||
extern void TF_SetDevice(TF_OperationDescription* desc, const char* device);
|
||||
extern void TF_SetDevice(TF_NodeDescription* desc, const char* device);
|
||||
|
||||
// The calls to TF_AddInput and TF_AddInputList must match (in number,
|
||||
// order, and type) the op declaration. For example, the "Concat" op
|
||||
@ -287,82 +285,74 @@ extern void TF_SetDevice(TF_OperationDescription* desc, const char* device);
|
||||
// single tensor), and TF_AddInputList() for the second input (since
|
||||
// it takes a list, even if you were to pass a list with a single
|
||||
// tensor), as in:
|
||||
// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c");
|
||||
// TF_NodeDescription* desc = TF_NewNode(graph, "Concat", "c");
|
||||
// TF_Port concat_dim_input = {...};
|
||||
// TF_AddInput(desc, concat_dim_input);
|
||||
// TF_Port values_inputs[5] = {{...}, ..., {...}};
|
||||
// TF_AddInputList(desc, 5, values_inputs);
|
||||
|
||||
// For inputs that take a single tensor.
|
||||
extern void TF_AddInput(TF_OperationDescription* desc, TF_Port input);
|
||||
extern void TF_AddInput(TF_NodeDescription* desc, TF_Port input);
|
||||
|
||||
// For inputs that take a list of tensors.
|
||||
// inputs must point to TF_Port[num_inputs].
|
||||
extern void TF_AddInputList(TF_OperationDescription* desc,
|
||||
const TF_Port* inputs, int num_inputs);
|
||||
extern void TF_AddInputList(TF_NodeDescription* desc, const TF_Port* inputs,
|
||||
int num_inputs);
|
||||
|
||||
// Call once per control input to `desc`.
|
||||
extern void TF_AddControlInput(TF_OperationDescription* desc,
|
||||
TF_Operation* input);
|
||||
extern void TF_AddControlInput(TF_NodeDescription* desc, TF_Node* input);
|
||||
|
||||
// Call some TF_SetAttr*() function for every attr that is not
|
||||
// inferred from an input and doesn't have a default value you wish to
|
||||
// keep.
|
||||
|
||||
// `value` must point to a string of length `length` bytes.
|
||||
extern void TF_SetAttrString(TF_OperationDescription* desc,
|
||||
const char* attr_name, const void* value,
|
||||
int length);
|
||||
extern void TF_SetAttrString(TF_NodeDescription* desc, const char* attr_name,
|
||||
const void* value, int length);
|
||||
// `values` and `lengths` both must have lengths `num_values`.
|
||||
// `values[i]` must point to a string of length `lengths[i]` bytes.
|
||||
extern void TF_SetAttrStringList(TF_OperationDescription* desc,
|
||||
extern void TF_SetAttrStringList(TF_NodeDescription* desc,
|
||||
const char* attr_name,
|
||||
const void* const* values, const int* lengths,
|
||||
int num_values);
|
||||
extern void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
|
||||
extern void TF_SetAttrInt(TF_NodeDescription* desc, const char* attr_name,
|
||||
int64_t value);
|
||||
extern void TF_SetAttrIntList(TF_OperationDescription* desc,
|
||||
const char* attr_name, const int64_t* values,
|
||||
int num_values);
|
||||
extern void TF_SetAttrFloat(TF_OperationDescription* desc,
|
||||
const char* attr_name, float value);
|
||||
extern void TF_SetAttrFloatList(TF_OperationDescription* desc,
|
||||
const char* attr_name, const float* values,
|
||||
int num_values);
|
||||
extern void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
|
||||
extern void TF_SetAttrIntList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* values, int num_values);
|
||||
extern void TF_SetAttrFloat(TF_NodeDescription* desc, const char* attr_name,
|
||||
float value);
|
||||
extern void TF_SetAttrFloatList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const float* values, int num_values);
|
||||
extern void TF_SetAttrBool(TF_NodeDescription* desc, const char* attr_name,
|
||||
unsigned char value);
|
||||
extern void TF_SetAttrBoolList(TF_OperationDescription* desc,
|
||||
const char* attr_name,
|
||||
extern void TF_SetAttrBoolList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const unsigned char* values, int num_values);
|
||||
extern void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
|
||||
extern void TF_SetAttrType(TF_NodeDescription* desc, const char* attr_name,
|
||||
TF_DataType value);
|
||||
extern void TF_SetAttrTypeList(TF_OperationDescription* desc,
|
||||
const char* attr_name, const TF_DataType* values,
|
||||
int num_values);
|
||||
extern void TF_SetAttrTypeList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const TF_DataType* values, int num_values);
|
||||
|
||||
// Set `num_dims` to -1 to represent "unknown rank". Otherwise,
|
||||
// `dims` points to an array of length `num_dims`. `dims[i]` must be
|
||||
// >= -1, with -1 meaning "unknown dimension".
|
||||
extern void TF_SetAttrShape(TF_OperationDescription* desc,
|
||||
const char* attr_name, const int64_t* dims,
|
||||
int num_dims);
|
||||
extern void TF_SetAttrShape(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* dims, int num_dims);
|
||||
// `dims` and `num_dims` must point to arrays of length `num_shapes`.
|
||||
// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise,
|
||||
// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]`
|
||||
// must be >= -1, with -1 meaning "unknown dimension".
|
||||
extern void TF_SetAttrShapeList(TF_OperationDescription* desc,
|
||||
const char* attr_name,
|
||||
extern void TF_SetAttrShapeList(TF_NodeDescription* desc, const char* attr_name,
|
||||
const int64_t* const* dims, const int* num_dims,
|
||||
int num_shapes);
|
||||
// `proto` must point to an array of `proto_len` bytes representing a
|
||||
// binary-serialized TensorShapeProto.
|
||||
extern void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
|
||||
extern void TF_SetAttrTensorShapeProto(TF_NodeDescription* desc,
|
||||
const char* attr_name, void* proto,
|
||||
int proto_len, TF_Status* status);
|
||||
// `protos` and `proto_lens` must point to arrays of length `num_shapes`.
|
||||
// `protos[i]` must point to an array of `proto_lens[i]` bytes
|
||||
// representing a binary-serialized TensorShapeProto.
|
||||
extern void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
|
||||
extern void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc,
|
||||
const char* attr_name,
|
||||
const void* const* protos,
|
||||
const int* proto_lens,
|
||||
@ -370,12 +360,11 @@ extern void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
|
||||
|
||||
// This functions takes ownership of *value (the
|
||||
// implementation will eventually call TF_DeleteTensor).
|
||||
extern void TF_SetAttrTensor(TF_OperationDescription* desc,
|
||||
const char* attr_name, TF_Tensor* value,
|
||||
TF_Status* status);
|
||||
extern void TF_SetAttrTensor(TF_NodeDescription* desc, const char* attr_name,
|
||||
TF_Tensor* value, TF_Status* status);
|
||||
// This functions takes ownership of values[0]..values[num_values-1] (the
|
||||
// implementation will eventually call TF_DeleteTensor on each).
|
||||
extern void TF_SetAttrTensorList(TF_OperationDescription* desc,
|
||||
extern void TF_SetAttrTensorList(TF_NodeDescription* desc,
|
||||
const char* attr_name,
|
||||
TF_Tensor* const* values, int num_values,
|
||||
TF_Status* status);
|
||||
@ -383,108 +372,100 @@ extern void TF_SetAttrTensorList(TF_OperationDescription* desc,
|
||||
// `proto` should point to a sequence of bytes of length `proto_len`
|
||||
// representing a binary serialization of an AttrValue protocol
|
||||
// buffer.
|
||||
extern void TF_SetAttrToAttrValueProto(TF_OperationDescription* desc,
|
||||
extern void TF_SetAttrToAttrValueProto(TF_NodeDescription* desc,
|
||||
const char* attr_name, const void* proto,
|
||||
size_t proto_len, TF_Status* status);
|
||||
|
||||
// If this function succeeds:
|
||||
// * *status is set to an OK value,
|
||||
// * a TF_Operation is added to the graph,
|
||||
// * a non-null value pointing to the added operation is returned --
|
||||
// * a TF_Node is added to the graph,
|
||||
// * a non-null value pointing to the added node is returned --
|
||||
// this value is valid until the underlying graph is deleted.
|
||||
// Otherwise:
|
||||
// * *status is set to a non-OK value,
|
||||
// * the graph is not modified,
|
||||
// * a null value is returned.
|
||||
// In either case, it deletes `desc`.
|
||||
extern TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
|
||||
extern TF_Node* TF_FinishNode(TF_NodeDescription* desc, TF_Status* status);
|
||||
|
||||
// TF_Node functions. Nodes are immutable once created, so these are all
|
||||
// query functions.
|
||||
|
||||
extern const char* TF_NodeName(TF_Node* node);
|
||||
extern const char* TF_NodeOpType(TF_Node* node);
|
||||
extern const char* TF_NodeDevice(TF_Node* node);
|
||||
|
||||
extern int TF_NodeNumOutputs(TF_Node* node);
|
||||
extern TF_DataType TF_NodeOutputType(TF_Port node_out);
|
||||
extern int TF_NodeOutputListLength(TF_Node* node, const char* arg_name,
|
||||
TF_Status* status);
|
||||
|
||||
// TF_Operation functions. Operations are immutable once created, so
|
||||
// these are all query functions.
|
||||
|
||||
extern const char* TF_OperationName(TF_Operation* oper);
|
||||
extern const char* TF_OperationOpType(TF_Operation* oper);
|
||||
extern const char* TF_OperationDevice(TF_Operation* oper);
|
||||
|
||||
extern int TF_OperationNumOutputs(TF_Operation* oper);
|
||||
extern TF_DataType TF_OperationOutputType(TF_Port oper_out);
|
||||
extern int TF_OperationOutputListLength(TF_Operation* oper,
|
||||
const char* arg_name,
|
||||
TF_Status* status);
|
||||
|
||||
extern int TF_OperationNumInputs(TF_Operation* oper);
|
||||
extern TF_DataType TF_OperationInputType(TF_Port oper_in);
|
||||
extern int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
|
||||
extern int TF_NodeNumInputs(TF_Node* node);
|
||||
extern TF_DataType TF_NodeInputType(TF_Port node_in);
|
||||
extern int TF_NodeInputListLength(TF_Node* node, const char* arg_name,
|
||||
TF_Status* status);
|
||||
|
||||
// In this code:
|
||||
// TF_Port producer = TF_OperationInput(consumer);
|
||||
// There is an edge from producer.oper's output (given by
|
||||
// producer.index) to consumer.oper's input (given by consumer.index).
|
||||
extern TF_Port TF_OperationInput(TF_Port oper_in);
|
||||
// TF_Port producer = TF_NodeInput(consumer);
|
||||
// There is an edge from producer.node's output (given by
|
||||
// producer.index) to consumer.node's input (given by consumer.index).
|
||||
extern TF_Port TF_NodeInput(TF_Port node_in);
|
||||
|
||||
// Get the number of current consumers of a specific output of an
|
||||
// operation. Note that this number can change when new operations
|
||||
// are added to the graph.
|
||||
extern int TF_OperationOutputNumConsumers(TF_Port oper_out);
|
||||
// Get the number of current consumers of a node's output. Note that
|
||||
// this number can change when new nodes are added to the graph.
|
||||
extern int TF_NodeOutputNumConsumers(TF_Port node_out);
|
||||
|
||||
// Get list of all current consumers of a specific output of an
|
||||
// operation. `consumers` must point to an array of length at least
|
||||
// `max_consumers` (ideally set to
|
||||
// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent
|
||||
// modification of the graph can increase the number of consumers of
|
||||
// an operation. Returns the number of output consumers (should match
|
||||
// TF_OperationOutputNumConsumers(oper_out)).
|
||||
extern int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers,
|
||||
// Get list of all current consumers of a node's output. consumers
|
||||
// must point to an array of length at least max_consumers (ideally
|
||||
// set to TF_NodeOutputNumConsumer(node_out)). Beware that a
|
||||
// concurrent modification of the graph can increase the number of
|
||||
// consumers of a node. Returns the number of output consumers
|
||||
// (should match TF_NodeOutputNumConsumers(node_out)).
|
||||
extern int TF_NodeOutputConsumers(TF_Port node_out, TF_Port* consumers,
|
||||
int max_consumers);
|
||||
|
||||
// Get the number of control inputs to an operation.
|
||||
extern int TF_OperationNumControlInputs(TF_Operation* oper);
|
||||
// Get the number of control inputs to a node.
|
||||
extern int TF_NodeNumControlInputs(TF_Node* node);
|
||||
|
||||
// Get list of all control inputs to an operation. `control_inputs` must
|
||||
// point to an array of length `max_control_inputs` (ideally set to
|
||||
// TF_OperationNumControlInputs(oper)). Returns the number of control
|
||||
// inputs (should match TF_OperationNumControlInputs(oper)).
|
||||
extern int TF_OperationGetControlInputs(TF_Operation* oper,
|
||||
TF_Operation** control_inputs,
|
||||
// Get list of all control inputs to a node. control_inputs must
|
||||
// point to an array of length max_control_inputs (ideally set to
|
||||
// TF_NodeNumControlInputs(node)). Returns the number of control
|
||||
// inputs (should match TF_NodeNumControlInputs(node)).
|
||||
extern int TF_NodeGetControlInputs(TF_Node* node, TF_Node** control_inputs,
|
||||
int max_control_inputs);
|
||||
|
||||
// Get the number of operations that have `*oper` as a control input.
|
||||
// Note that this number can change when new operations are added to
|
||||
// the graph.
|
||||
extern int TF_OperationNumControlOutputs(TF_Operation* oper);
|
||||
// Get the number of nodes that have *node as a control inputs.
|
||||
// Note that this number can change when new nodes are added to the
|
||||
// graph.
|
||||
extern int TF_NodeNumControlOutputs(TF_Node* node);
|
||||
|
||||
// Get the list of operations that have `*oper` as a control input.
|
||||
// `control_outputs` must point to an array of length at least
|
||||
// `max_control_outputs` (ideally set to
|
||||
// TF_OperationNumControlOutputs(oper)). Beware that a concurrent
|
||||
// Get the list of nodes that have *node as a control input.
|
||||
// control_outputs must point to an array of length at least
|
||||
// max_control_outputs (ideally set to
|
||||
// TF_NodeNumControlOutputs(node)). Beware that a concurrent
|
||||
// modification of the graph can increase the number of control
|
||||
// outputs. Returns the number of control outputs (should match
|
||||
// TF_OperationNumControlOutputs(oper)).
|
||||
extern int TF_OperationGetControlOutputs(TF_Operation* oper,
|
||||
TF_Operation** control_outputs,
|
||||
// TF_NodeNumControlOutputs(node)).
|
||||
extern int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs,
|
||||
int max_control_outputs);
|
||||
|
||||
// Sets `output_attr_value` to the binary-serialized AttrValue proto
|
||||
// representation of the value of the `attr_name` attr of `oper`.
|
||||
extern void TF_OperationGetAttrValueProto(TF_Operation* oper,
|
||||
const char* attr_name,
|
||||
// representation of the value of the `attr_name` attr of `node`.
|
||||
extern void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name,
|
||||
TF_Buffer* output_attr_value,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the operation in the graph with `oper_name`. Returns nullptr if
|
||||
// no operation found.
|
||||
extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph,
|
||||
const char* oper_name);
|
||||
// Returns the node in the graph with `node_name`. Returns nullptr if
|
||||
// no node found.
|
||||
extern TF_Node* TF_GraphNodeByName(TF_Graph* graph, const char* node_name);
|
||||
|
||||
// Iterate through the operations of a graph. To use:
|
||||
// Iterate through the nodes of a graph. To use:
|
||||
// size_t pos = 0;
|
||||
// TF_Operation* oper;
|
||||
// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
|
||||
// DoSomethingWithOperation(oper);
|
||||
// TF_Node* node;
|
||||
// while ((node = TF_GraphNextNode(graph, &pos)) != nullptr) {
|
||||
// DoSomethingWithNode(node);
|
||||
// }
|
||||
extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
|
||||
extern TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos);
|
||||
|
||||
// Note: The following two functions may fail on very large protos in the
|
||||
// future.
|
||||
@ -492,19 +473,18 @@ extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
|
||||
extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
|
||||
TF_Status* status);
|
||||
|
||||
extern void TF_OperationToNodeDef(TF_Operation* oper,
|
||||
TF_Buffer* output_node_def,
|
||||
extern void TF_NodeToNodeDef(TF_Node* node, TF_Buffer* output_node_def,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO(josh11b): Query attrs for an operation.
|
||||
// TODO(josh11b): Query attrs for a Node.
|
||||
|
||||
// TODO(cwhipkey): Query shape for operation outputs.
|
||||
// TODO(cwhipkey): Query shape for node outputs.
|
||||
|
||||
// TODO(josh11b,mrry): Import GraphDef into TF_Graph.
|
||||
|
||||
// TODO(andydavis): Function to add gradients to a graph.
|
||||
|
||||
// TODO(josh11b): Register OpDef, available to all operations added
|
||||
// TODO(josh11b): Register OpDef, available to all nodes added
|
||||
// to this graph.
|
||||
|
||||
// The following two may both benefit from a subgraph-definition API
|
||||
@ -550,8 +530,8 @@ extern void TF_SessionRun(TF_SessionWithGraph* session,
|
||||
// Output tensors
|
||||
const TF_Port* outputs, TF_Tensor** output_values,
|
||||
int noutputs,
|
||||
// Target operations
|
||||
const TF_Operation* const* target_opers, int ntargets,
|
||||
// Target nodes
|
||||
const TF_Node* const* target_nodes, int ntargets,
|
||||
// RunMetadata
|
||||
TF_Buffer* run_metadata,
|
||||
// Output status
|
||||
@ -563,8 +543,8 @@ extern void TF_SessionPRunSetup(TF_SessionWithGraph*,
|
||||
const TF_Port* inputs, int ninputs,
|
||||
// Output names
|
||||
const TF_Port* outputs, int noutputs,
|
||||
// Target operations
|
||||
const TF_Operation* const* target_opers,
|
||||
// Target nodes
|
||||
const TF_Node* const* target_nodes,
|
||||
int ntargets,
|
||||
// Output handle
|
||||
const char** handle,
|
||||
@ -579,9 +559,8 @@ extern void TF_SessionPRun(TF_SessionWithGraph*, const char* handle,
|
||||
// Output tensors
|
||||
const TF_Port* outputs, TF_Tensor** output_values,
|
||||
int noutputs,
|
||||
// Target operations
|
||||
const TF_Operation* const* target_opers,
|
||||
int ntargets,
|
||||
// Target nodes
|
||||
const TF_Node* const* target_nodes, int ntargets,
|
||||
// Output status
|
||||
TF_Status*);
|
||||
|
||||
@ -643,9 +622,10 @@ extern void TF_Run(TF_Session*,
|
||||
// Input tensors
|
||||
const char** input_names, TF_Tensor** inputs, int ninputs,
|
||||
// Output tensors
|
||||
const char** output_names, TF_Tensor** outputs, int noutputs,
|
||||
// Target operations
|
||||
const char** target_oper_names, int ntargets,
|
||||
const char** output_tensor_names, TF_Tensor** outputs,
|
||||
int noutputs,
|
||||
// Target nodes
|
||||
const char** target_node_names, int ntargets,
|
||||
// RunMetadata
|
||||
TF_Buffer* run_metadata,
|
||||
// Output status
|
||||
@ -663,9 +643,9 @@ extern void TF_PRunSetup(TF_Session*,
|
||||
// Input names
|
||||
const char** input_names, int ninputs,
|
||||
// Output names
|
||||
const char** output_names, int noutputs,
|
||||
// Target operations
|
||||
const char** target_oper_names, int ntargets,
|
||||
const char** output_tensor_names, int noutputs,
|
||||
// Target nodes
|
||||
const char** target_node_names, int ntargets,
|
||||
// Output handle
|
||||
const char** handle,
|
||||
// Output status
|
||||
@ -678,10 +658,10 @@ extern void TF_PRun(TF_Session*, const char* handle,
|
||||
// Input tensors
|
||||
const char** input_names, TF_Tensor** inputs, int ninputs,
|
||||
// Output tensors
|
||||
const char** output_names, TF_Tensor** outputs,
|
||||
const char** output_tensor_names, TF_Tensor** outputs,
|
||||
int noutputs,
|
||||
// Target operations
|
||||
const char** target_oper_names, int ntargets,
|
||||
// Target nodes
|
||||
const char** target_node_names, int ntargets,
|
||||
// Output status
|
||||
TF_Status*);
|
||||
|
||||
|
@ -202,33 +202,32 @@ static TF_Tensor* Int32Tensor(int32 v) {
|
||||
&Int32Deallocator, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed");
|
||||
TF_Node* Placeholder(TF_Graph* graph, TF_Status* s) {
|
||||
TF_NodeDescription* desc = TF_NewNode(graph, "Placeholder", "feed");
|
||||
TF_SetAttrType(desc, "dtype", TF_INT32);
|
||||
return TF_FinishOperation(desc, s);
|
||||
return TF_FinishNode(desc, s);
|
||||
}
|
||||
|
||||
TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar");
|
||||
TF_Node* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
|
||||
TF_NodeDescription* desc = TF_NewNode(graph, "Const", "scalar");
|
||||
TF_SetAttrTensor(desc, "value", Int32Tensor(v), s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
TF_SetAttrType(desc, "dtype", TF_INT32);
|
||||
return TF_FinishOperation(desc, s);
|
||||
return TF_FinishNode(desc, s);
|
||||
}
|
||||
|
||||
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add");
|
||||
TF_Node* Add(TF_Node* l, TF_Node* r, TF_Graph* graph, TF_Status* s) {
|
||||
TF_NodeDescription* desc = TF_NewNode(graph, "AddN", "add");
|
||||
TF_Port add_inputs[2] = {{l, 0}, {r, 0}};
|
||||
TF_AddInputList(desc, add_inputs, 2);
|
||||
return TF_FinishOperation(desc, s);
|
||||
return TF_FinishNode(desc, s);
|
||||
}
|
||||
|
||||
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
|
||||
TF_Node* Neg(TF_Node* n, TF_Graph* graph, TF_Status* s) {
|
||||
TF_NodeDescription* desc = TF_NewNode(graph, "Neg", "neg");
|
||||
TF_Port neg_input = {n, 0};
|
||||
TF_AddInput(desc, neg_input);
|
||||
return TF_FinishOperation(desc, s);
|
||||
return TF_FinishNode(desc, s);
|
||||
}
|
||||
|
||||
bool IsPlaceholder(const NodeDef& node_def) {
|
||||
@ -319,10 +318,10 @@ bool GetGraphDef(TF_Graph* graph, GraphDef* graph_def) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) {
|
||||
bool GetNodeDef(TF_Node* node, NodeDef* node_def) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Buffer* buffer = TF_NewBuffer();
|
||||
TF_OperationToNodeDef(oper, buffer, s);
|
||||
TF_NodeToNodeDef(node, buffer, s);
|
||||
bool ret = TF_GetCode(s) == TF_OK;
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
|
||||
@ -331,10 +330,10 @@ bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
|
||||
bool GetAttrValue(TF_Node* node, const char* attr_name,
|
||||
tensorflow::AttrValue* attr_value, TF_Status* s) {
|
||||
TF_Buffer* buffer = TF_NewBuffer();
|
||||
TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
|
||||
TF_NodeGetAttrValueProto(node, attr_name, buffer, s);
|
||||
bool ret = TF_GetCode(s) == TF_OK;
|
||||
if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
|
||||
TF_DeleteBuffer(buffer);
|
||||
@ -345,83 +344,82 @@ TEST(CAPI, Graph) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
// Make a placeholder oper.
|
||||
TF_Operation* feed = Placeholder(graph, s);
|
||||
// Make a placeholder node.
|
||||
TF_Node* feed = Placeholder(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Test TF_Operation*() query functions.
|
||||
EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
|
||||
EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
|
||||
EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
|
||||
EXPECT_EQ(1, TF_OperationNumOutputs(feed));
|
||||
EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Port{feed, 0}));
|
||||
EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
|
||||
// Test TF_Node*() query functions.
|
||||
EXPECT_EQ(string("feed"), string(TF_NodeName(feed)));
|
||||
EXPECT_EQ(string("Placeholder"), string(TF_NodeOpType(feed)));
|
||||
EXPECT_EQ(string(""), string(TF_NodeDevice(feed)));
|
||||
EXPECT_EQ(1, TF_NodeNumOutputs(feed));
|
||||
EXPECT_EQ(TF_INT32, TF_NodeOutputType(TF_Port{feed, 0}));
|
||||
EXPECT_EQ(1, TF_NodeOutputListLength(feed, "output", s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
EXPECT_EQ(0, TF_OperationNumInputs(feed));
|
||||
EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Port{feed, 0}));
|
||||
EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
|
||||
EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
|
||||
EXPECT_EQ(0, TF_NodeNumInputs(feed));
|
||||
EXPECT_EQ(0, TF_NodeOutputNumConsumers(TF_Port{feed, 0}));
|
||||
EXPECT_EQ(0, TF_NodeNumControlInputs(feed));
|
||||
EXPECT_EQ(0, TF_NodeNumControlOutputs(feed));
|
||||
|
||||
tensorflow::AttrValue attr_value;
|
||||
ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
|
||||
EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
|
||||
|
||||
// Test not found errors in TF_Operation*() query functions.
|
||||
EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
|
||||
// Test not found errors in TF_Node*() query functions.
|
||||
EXPECT_EQ(-1, TF_NodeOutputListLength(feed, "bogus", s));
|
||||
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
|
||||
|
||||
ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
|
||||
EXPECT_EQ(string("Operation has no attr named 'missing'."),
|
||||
string(TF_Message(s)));
|
||||
EXPECT_EQ(string("Node has no attr named 'missing'."), string(TF_Message(s)));
|
||||
|
||||
// Make a constant oper with the scalar "3".
|
||||
TF_Operation* three = ScalarConst(3, graph, s);
|
||||
// Make a constant node with the scalar "3".
|
||||
TF_Node* three = ScalarConst(3, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Add oper.
|
||||
TF_Operation* add = Add(feed, three, graph, s);
|
||||
// Add node.
|
||||
TF_Node* add = Add(feed, three, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Test TF_Operation*() query functions.
|
||||
EXPECT_EQ(string("add"), string(TF_OperationName(add)));
|
||||
EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
|
||||
EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
|
||||
EXPECT_EQ(1, TF_OperationNumOutputs(add));
|
||||
EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Port{add, 0}));
|
||||
EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
|
||||
// Test TF_Node*() query functions.
|
||||
EXPECT_EQ(string("add"), string(TF_NodeName(add)));
|
||||
EXPECT_EQ(string("AddN"), string(TF_NodeOpType(add)));
|
||||
EXPECT_EQ(string(""), string(TF_NodeDevice(add)));
|
||||
EXPECT_EQ(1, TF_NodeNumOutputs(add));
|
||||
EXPECT_EQ(TF_INT32, TF_NodeOutputType(TF_Port{add, 0}));
|
||||
EXPECT_EQ(1, TF_NodeOutputListLength(add, "sum", s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
EXPECT_EQ(2, TF_OperationNumInputs(add));
|
||||
EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
|
||||
EXPECT_EQ(2, TF_NodeNumInputs(add));
|
||||
EXPECT_EQ(2, TF_NodeInputListLength(add, "inputs", s));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Port{add, 0}));
|
||||
EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Port{add, 1}));
|
||||
TF_Port add_in_0 = TF_OperationInput(TF_Port{add, 0});
|
||||
EXPECT_EQ(feed, add_in_0.oper);
|
||||
EXPECT_EQ(TF_INT32, TF_NodeInputType(TF_Port{add, 0}));
|
||||
EXPECT_EQ(TF_INT32, TF_NodeInputType(TF_Port{add, 1}));
|
||||
TF_Port add_in_0 = TF_NodeInput(TF_Port{add, 0});
|
||||
EXPECT_EQ(feed, add_in_0.node);
|
||||
EXPECT_EQ(0, add_in_0.index);
|
||||
TF_Port add_in_1 = TF_OperationInput(TF_Port{add, 1});
|
||||
EXPECT_EQ(three, add_in_1.oper);
|
||||
TF_Port add_in_1 = TF_NodeInput(TF_Port{add, 1});
|
||||
EXPECT_EQ(three, add_in_1.node);
|
||||
EXPECT_EQ(0, add_in_1.index);
|
||||
EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Port{add, 0}));
|
||||
EXPECT_EQ(0, TF_OperationNumControlInputs(add));
|
||||
EXPECT_EQ(0, TF_OperationNumControlOutputs(add));
|
||||
EXPECT_EQ(0, TF_NodeOutputNumConsumers(TF_Port{add, 0}));
|
||||
EXPECT_EQ(0, TF_NodeNumControlInputs(add));
|
||||
EXPECT_EQ(0, TF_NodeNumControlOutputs(add));
|
||||
|
||||
ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
|
||||
EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
|
||||
ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
|
||||
EXPECT_EQ(attr_value.i(), 2);
|
||||
|
||||
// Placeholder oper now has a consumer.
|
||||
ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Port{feed, 0}));
|
||||
// Placeholder node now has a consumer.
|
||||
ASSERT_EQ(1, TF_NodeOutputNumConsumers(TF_Port{feed, 0}));
|
||||
TF_Port feed_port;
|
||||
EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Port{feed, 0}, &feed_port, 1));
|
||||
EXPECT_EQ(add, feed_port.oper);
|
||||
EXPECT_EQ(1, TF_NodeOutputConsumers(TF_Port{feed, 0}, &feed_port, 1));
|
||||
EXPECT_EQ(add, feed_port.node);
|
||||
EXPECT_EQ(0, feed_port.index);
|
||||
|
||||
// The scalar const oper also has a consumer.
|
||||
ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Port{three, 0}));
|
||||
// The scalar const node also has a consumer.
|
||||
ASSERT_EQ(1, TF_NodeOutputNumConsumers(TF_Port{three, 0}));
|
||||
TF_Port three_port;
|
||||
EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Port{three, 0}, &three_port, 1));
|
||||
EXPECT_EQ(add, three_port.oper);
|
||||
EXPECT_EQ(1, TF_NodeOutputConsumers(TF_Port{three, 0}, &three_port, 1));
|
||||
EXPECT_EQ(add, three_port.node);
|
||||
EXPECT_EQ(1, three_port.index);
|
||||
|
||||
// Serialize to GraphDef.
|
||||
@ -450,8 +448,8 @@ TEST(CAPI, Graph) {
|
||||
EXPECT_TRUE(found_scalar_const);
|
||||
EXPECT_TRUE(found_add);
|
||||
|
||||
// Add another oper to the graph.
|
||||
TF_Operation* neg = Neg(add, graph, s);
|
||||
// Add another node to the graph.
|
||||
TF_Node* neg = Neg(add, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Serialize to NodeDef.
|
||||
@ -471,13 +469,13 @@ TEST(CAPI, Graph) {
|
||||
EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2));
|
||||
|
||||
// Look up some nodes by name.
|
||||
TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
|
||||
TF_Node* neg2 = TF_GraphNodeByName(graph, "neg");
|
||||
EXPECT_TRUE(neg == neg2);
|
||||
NodeDef node_def2;
|
||||
ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
|
||||
EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));
|
||||
|
||||
TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
|
||||
TF_Node* feed2 = TF_GraphNodeByName(graph, "feed");
|
||||
EXPECT_TRUE(feed == feed2);
|
||||
ASSERT_TRUE(GetNodeDef(feed, &node_def));
|
||||
ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
|
||||
@ -489,22 +487,22 @@ TEST(CAPI, Graph) {
|
||||
found_add = false;
|
||||
bool found_neg = false;
|
||||
size_t pos = 0;
|
||||
TF_Operation* oper;
|
||||
while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
|
||||
if (oper == feed) {
|
||||
TF_Node* node;
|
||||
while ((node = TF_GraphNextNode(graph, &pos)) != nullptr) {
|
||||
if (node == feed) {
|
||||
EXPECT_FALSE(found_placeholder);
|
||||
found_placeholder = true;
|
||||
} else if (oper == three) {
|
||||
} else if (node == three) {
|
||||
EXPECT_FALSE(found_scalar_const);
|
||||
found_scalar_const = true;
|
||||
} else if (oper == add) {
|
||||
} else if (node == add) {
|
||||
EXPECT_FALSE(found_add);
|
||||
found_add = true;
|
||||
} else if (oper == neg) {
|
||||
} else if (node == neg) {
|
||||
EXPECT_FALSE(found_neg);
|
||||
found_neg = true;
|
||||
} else {
|
||||
ASSERT_TRUE(GetNodeDef(oper, &node_def));
|
||||
ASSERT_TRUE(GetNodeDef(node, &node_def));
|
||||
ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def);
|
||||
}
|
||||
}
|
||||
@ -534,7 +532,7 @@ class CSessionWithGraph {
|
||||
}
|
||||
|
||||
void SetInputs(
|
||||
std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
|
||||
std::initializer_list<std::pair<TF_Node*, TF_Tensor*>> inputs) {
|
||||
DeleteInputValues();
|
||||
inputs_.clear();
|
||||
for (const auto& p : inputs) {
|
||||
@ -543,17 +541,17 @@ class CSessionWithGraph {
|
||||
}
|
||||
}
|
||||
|
||||
void SetOutputs(std::initializer_list<TF_Operation*> outputs) {
|
||||
void SetOutputs(std::initializer_list<TF_Node*> outputs) {
|
||||
ResetOutputValues();
|
||||
outputs_.clear();
|
||||
for (TF_Operation* o : outputs) {
|
||||
for (TF_Node* o : outputs) {
|
||||
outputs_.emplace_back(TF_Port{o, 0});
|
||||
}
|
||||
}
|
||||
|
||||
void SetTargets(std::initializer_list<TF_Operation*> targets) {
|
||||
void SetTargets(std::initializer_list<TF_Node*> targets) {
|
||||
targets_.clear();
|
||||
for (TF_Operation* t : targets) {
|
||||
for (TF_Node* t : targets) {
|
||||
targets_.emplace_back(t);
|
||||
}
|
||||
}
|
||||
@ -574,8 +572,7 @@ class CSessionWithGraph {
|
||||
TF_Tensor** output_values_ptr =
|
||||
output_values_.empty() ? nullptr : &output_values_[0];
|
||||
|
||||
TF_Operation* const* targets_ptr =
|
||||
targets_.empty() ? nullptr : &targets_[0];
|
||||
TF_Node* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
|
||||
|
||||
TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr,
|
||||
inputs_.size(), outputs_ptr, output_values_ptr,
|
||||
@ -618,23 +615,23 @@ class CSessionWithGraph {
|
||||
std::vector<TF_Tensor*> input_values_;
|
||||
std::vector<TF_Port> outputs_;
|
||||
std::vector<TF_Tensor*> output_values_;
|
||||
std::vector<TF_Operation*> targets_;
|
||||
std::vector<TF_Node*> targets_;
|
||||
};
|
||||
|
||||
TEST(CAPI, SessionWithGraph) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
// Make a placeholder operation.
|
||||
TF_Operation* feed = Placeholder(graph, s);
|
||||
// Make a placeholder node.
|
||||
TF_Node* feed = Placeholder(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Make a constant operation with the scalar "2".
|
||||
TF_Operation* two = ScalarConst(2, graph, s);
|
||||
// Make a constant node with the scalar "2".
|
||||
TF_Node* two = ScalarConst(2, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Add operation.
|
||||
TF_Operation* add = Add(feed, two, graph, s);
|
||||
// Add node.
|
||||
TF_Node* add = Add(feed, two, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Create a session for this graph.
|
||||
@ -655,11 +652,11 @@ TEST(CAPI, SessionWithGraph) {
|
||||
static_cast<tensorflow::int32*>(TF_TensorData(out));
|
||||
EXPECT_EQ(3 + 2, *output_contents);
|
||||
|
||||
// Add another operation to the graph.
|
||||
TF_Operation* neg = Neg(add, graph, s);
|
||||
// Add another node to the graph.
|
||||
TF_Node* neg = Neg(add, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Run up to the new operation.
|
||||
// Run up to the new node.
|
||||
csession.SetInputs({{feed, Int32Tensor(7)}});
|
||||
csession.SetOutputs({neg});
|
||||
csession.Run(s);
|
||||
|
Loading…
Reference in New Issue
Block a user