185 lines
6.7 KiB
C++
185 lines
6.7 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/python_api.h"
|
|
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
// Adds a control edge from `input` to `op`, then reports the change through
// RecordMutation with the description "adding control input".
// Both the edge insertion and the RecordMutation call run while holding
// graph->mu.
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
  auto* const src_node = &input->node;
  auto* const dst_node = &op->node;

  mutex_lock graph_lock(graph->mu);
  graph->graph.AddControlEdge(src_node, dst_node);
  RecordMutation(graph, *op, "adding control input");
}
|
|
|
|
// Parses `attr_value_proto` as a serialized AttrValue and stores it on `op`
// under `attr_name`.
//
// If parsing fails, `status` is set to InvalidArgument and the graph is not
// touched (the graph lock is never taken). On success the attribute is added
// under graph->mu and the change is reported via RecordMutation; `status` is
// left unmodified in that case.
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
             TF_Buffer* attr_value_proto, TF_Status* status) {
  AttrValue parsed_value;
  const bool parsed_ok = parsed_value.ParseFromArray(attr_value_proto->data,
                                                     attr_value_proto->length);
  if (parsed_ok) {
    mutex_lock graph_lock(graph->mu);
    op->node.AddAttr(attr_name, parsed_value);
    RecordMutation(graph, *op, "setting attribute");
    return;
  }

  status->status =
      tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
}
|
|
|
|
// Removes the attribute `attr_name` from `op` while holding graph->mu, and
// reports the change via RecordMutation.
//
// NOTE(review): `status` is never written here; presumably it is kept for
// signature parity with SetAttr.
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status) {
  (void)status;
  auto& node = op->node;

  mutex_lock graph_lock(graph->mu);
  node.ClearAttr(attr_name);
  RecordMutation(graph, *op, "clearing attribute");
}
|
|
|
|
// Sets the requested device string of `op` to `device` while holding
// graph->mu, then reports the change via RecordMutation.
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
  auto& node = op->node;

  mutex_lock graph_lock(graph->mu);
  node.set_requested_device(device);
  RecordMutation(graph, *op, "setting device");
}
|
|
|
|
// Rewires input `dst` of `dst.oper` so that it is fed by output `new_src`.
//
// Before touching the graph, the function:
//   * requires both nodes to have a shape-inference context,
//   * bounds-checks new_src.index against the source's outputs and dst.index
//     against the destination's inputs (negative indices are rejected too),
//   * merges the source output shape into the destination input shape and
//     fails with InvalidArgument if they are incompatible.
// Then Graph::UpdateEdge performs the rewiring. Its status is returned in
// `status`, and on success the destination op is reported via
// RecordMutation.
//
// NOTE(review): MergeInput refines the destination's input shape before
// Graph::UpdateEdge runs, so a failure in Graph::UpdateEdge leaves that
// refinement in place.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                TF_Status* status) {
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);
  // GetContext can return nullptr; report it instead of dereferencing it.
  if (ic == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge. No shape inference context for source node '",
        new_src.oper->node.name(), "'.");
    return;
  }

  // A negative index would slip past the upper-bound check below and reach
  // ic->output() out of range.
  if (new_src.index < 0) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index, "] is negative.");
    return;
  }
  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge. No shape inference context for destination "
        "node '",
        dst.oper->node.name(), "'.");
    return;
  }

  // Same reasoning as for the source index: reject negatives explicitly.
  if (dst.index < 0) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index, "] is negative.");
    return;
  }
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }

  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}
|
|
|
|
// Removes every incoming control edge of `op` while holding graph->mu.
//
// Like the other graph-mutating entry points in this file (AddControlInput,
// SetAttr, UpdateEdge, ...), the change is reported via RecordMutation. It is
// reported only when at least one edge was actually removed, so calling this
// on an op with no control inputs is still a no-op.
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
  mutex_lock l(graph->mu);

  // Collect first and remove afterwards: RemoveControlEdge presumably
  // updates op->node.in_edges(), which must not change while we iterate it.
  std::vector<const Edge*> control_edges;
  for (const Edge* edge : op->node.in_edges()) {
    if (edge->IsControlEdge()) control_edges.push_back(edge);
  }
  if (control_edges.empty()) return;

  for (const Edge* edge : control_edges) {
    graph->graph.RemoveControlEdge(edge);
  }
  // Previously this mutation was never recorded, unlike every other edit
  // helper here.
  RecordMutation(graph, *op, "removing control inputs");
}
|
|
|
|
// Forwards `require` to the graph's shape refiner via
// set_require_shape_inference_fns, while holding graph->mu.
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
  auto& shape_refiner = graph->refiner;

  mutex_lock graph_lock(graph->mu);
  shape_refiner.set_require_shape_inference_fns(require);
}
|
|
|
|
// Calls ExtendSessionGraphHelper(session, status), then clears
// session->extend_before_run.
//
// The flag is cleared unconditionally, even if the helper reported an error
// through `status`.
// NOTE(review): presumably this hands responsibility for extending the
// session to the caller from now on; confirm against the code that reads
// extend_before_run before relying on this.
void ExtendSession(TF_Session* session, TF_Status* status) {
  ExtendSessionGraphHelper(session, status);
  session->extend_before_run = false;
}
|
|
|
|
// Returns a serialized CppShapeInferenceResult::HandleData (with is_set=true)
// holding the (shape, dtype) pairs that the shape-inference context reports
// via output_handle_shapes_and_types(output.index).
//
// Returns "" if no such data is recorded for that output.
// Crashes via CHECK if the node has no inference context or output.index is
// outside [0, num_outputs()).
std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
  Node* node = &output.oper->node;
  CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  {
    // Hold the graph lock only while reading the inference context;
    // serialization below does not touch the graph.
    mutex_lock l(graph->mu);
    tensorflow::shape_inference::InferenceContext* ic =
        graph->refiner.GetContext(node);
    CHECK(ic != nullptr);
    // Without the lower bound, a negative index would pass CHECK_LT and
    // reach output_handle_shapes_and_types() out of range.
    CHECK_GE(output.index, 0);
    CHECK_LT(output.index, ic->num_outputs());
    const auto* shapes_and_types =
        ic->output_handle_shapes_and_types(output.index);
    if (shapes_and_types == nullptr) return "";

    for (const auto& p : *shapes_and_types) {
      auto* out_shape_and_type = handle_data.add_shape_and_type();
      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
      out_shape_and_type->set_dtype(p.dtype);
    }
  }
  std::string result;
  handle_data.SerializeToString(&result);
  return result;
}
|
|
|
|
// Parses `proto` (`proto_len` bytes) as a serialized
// CppShapeInferenceResult::HandleData and installs its (shape, dtype) pairs
// as the handle shapes/types of `output` in the graph's inference context.
//
// Errors are reported through `status`, and nothing is installed when any
// of them occurs:
//   * InvalidArgument if the proto cannot be parsed or the node has no
//     shape-inference context,
//   * OutOfRange if output.index is outside [0, num_outputs()),
//   * whatever MakeShapeFromShapeProto returns for an invalid shape.
// The graph mutex is held from the context lookup to the end of the call.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                           size_t proto_len, TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  if (!handle_data.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't deserialize HandleData proto");
    return;
  }
  DCHECK(handle_data.is_set());

  tensorflow::mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&output.oper->node);
  // Report a missing context instead of dereferencing nullptr.
  if (ic == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't set handle shape and type. No shape inference context for "
        "node '",
        output.oper->node.name(), "'.");
    return;
  }
  // set_output_handle_shapes_and_types is called with output.index below;
  // validate it first rather than writing out of range.
  if (output.index < 0 || output.index >= ic->num_outputs()) {
    status->status = tensorflow::errors::OutOfRange(
        "Couldn't set handle shape and type. Output index [", output.index,
        "] is out of range; the node has ", ic->num_outputs(), " outputs.");
    return;
  }

  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
    if (TF_GetCode(status) != TF_OK) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
|
|
|
|
// Calls Graph::AddWhileInputHack to connect output `new_src` to `dst` while
// holding graph->mu, storing the resulting status in `status`.
// On success, the change to `dst` is reported via RecordMutation.
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status) {
  auto* const src_node = &new_src.oper->node;
  auto* const dst_node = &dst->node;

  mutex_lock graph_lock(graph->mu);
  status->status =
      graph->graph.AddWhileInputHack(src_node, new_src.index, dst_node);
  if (TF_GetCode(status) != TF_OK) return;

  // Only the destination node is affected as far as running this graph in a
  // session is concerned, so the source node is not recorded as modified.
  RecordMutation(graph, *dst, "adding input tensor");
}
|
|
|
|
} // namespace tensorflow
|