Prepare to remove a bunch of proto.h includes from tensorflow/core headers
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. This CL does not remove any actual headers, but changes a bunch of files so that header removal is possible in a followup CL. It also marks the headers that will be removed with // TODO(b/62899350): Remove RELNOTES: n/a PiperOrigin-RevId: 160552878
This commit is contained in:
parent
9b11f45819
commit
e85d3df92d
@ -62,6 +62,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/cc:scope_internal",
|
"//tensorflow/cc:scope_internal",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
|
@ -28,12 +28,15 @@ limitations under the License.
|
|||||||
#endif
|
#endif
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||||
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
@ -1587,6 +1590,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
|
|||||||
|
|
||||||
// TF_Graph functions ---------------------------------------------------------
|
// TF_Graph functions ---------------------------------------------------------
|
||||||
|
|
||||||
|
TF_Graph::TF_Graph()
|
||||||
|
: graph(tensorflow::OpRegistry::Global()),
|
||||||
|
refiner(graph.versions().producer(), graph.op_registry()),
|
||||||
|
num_sessions(0),
|
||||||
|
delete_requested(false),
|
||||||
|
parent(nullptr),
|
||||||
|
parent_inputs(nullptr) {}
|
||||||
|
|
||||||
TF_Graph* TF_NewGraph() { return new TF_Graph; }
|
TF_Graph* TF_NewGraph() { return new TF_Graph; }
|
||||||
|
|
||||||
void TF_DeleteGraph(TF_Graph* g) {
|
void TF_DeleteGraph(TF_Graph* g) {
|
||||||
|
@ -56,13 +56,8 @@ struct TF_Library {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct TF_Graph {
|
struct TF_Graph {
|
||||||
TF_Graph()
|
TF_Graph();
|
||||||
: graph(tensorflow::OpRegistry::Global()),
|
|
||||||
refiner(graph.versions().producer(), graph.op_registry()),
|
|
||||||
num_sessions(0),
|
|
||||||
delete_requested(false),
|
|
||||||
parent(nullptr),
|
|
||||||
parent_inputs(nullptr) {}
|
|
||||||
tensorflow::mutex mu;
|
tensorflow::mutex mu;
|
||||||
tensorflow::Graph graph GUARDED_BY(mu);
|
tensorflow::Graph graph GUARDED_BY(mu);
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
|
@ -45,6 +45,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:all_kernels",
|
"//tensorflow/core:all_kernels",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
@ -432,6 +433,7 @@ cc_library_with_android_deps(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:op_gen_lib",
|
"//tensorflow/core:op_gen_lib",
|
||||||
|
"//tensorflow/core:op_gen_overrides_proto_cc",
|
||||||
"//tensorflow/core:proto_text",
|
"//tensorflow/core:proto_text",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
|
@ -18,8 +18,12 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||||
|
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.pb_text.h"
|
#include "tensorflow/core/framework/types.pb_text.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
#include "tensorflow/cc/framework/testutil.h"
|
#include "tensorflow/cc/framework/testutil.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
|||||||
Scope Scope::NewRootScope() {
|
Scope Scope::NewRootScope() {
|
||||||
Graph* graph = new Graph(OpRegistry::Global());
|
Graph* graph = new Graph(OpRegistry::Global());
|
||||||
ShapeRefiner* refiner =
|
ShapeRefiner* refiner =
|
||||||
new ShapeRefiner(graph->versions().producer(), graph->op_registry());
|
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
|
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph_def_util.h"
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
@ -139,6 +139,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core:tensorflow_opensource",
|
"//tensorflow/core:tensorflow_opensource",
|
||||||
"//tensorflow/core/kernels:constant_op",
|
"//tensorflow/core/kernels:constant_op",
|
||||||
|
@ -31,9 +31,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/local_device.h"
|
#include "tensorflow/core/common_runtime/local_device.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
|
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/platform/mem.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
|
@ -54,8 +54,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
|
|||||||
float normalize_min;
|
float normalize_min;
|
||||||
float border_level;
|
float border_level;
|
||||||
int number_colors;
|
int number_colors;
|
||||||
::tensorflow::TensorShapeProto output_image_shape;
|
::tensorflow::PartialTensorShape output_image_shape;
|
||||||
::tensorflow::TensorShapeProto output_data_window;
|
::tensorflow::PartialTensorShape output_data_window;
|
||||||
|
|
||||||
uint8 Cblack = 0;
|
uint8 Cblack = 0;
|
||||||
uint8 Cwhite = 255;
|
uint8 Cwhite = 255;
|
||||||
@ -109,15 +109,15 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
|
|||||||
input_Yvalue =
|
input_Yvalue =
|
||||||
input_tensor.shape().dim_size(0); // Y value is the number of rows
|
input_tensor.shape().dim_size(0); // Y value is the number of rows
|
||||||
|
|
||||||
output_Ximage = output_image_shape.dim(0).size();
|
output_Ximage = output_image_shape.dim_size(0);
|
||||||
output_Yimage = output_image_shape.dim(1).size();
|
output_Yimage = output_image_shape.dim_size(1);
|
||||||
output_Cimage = output_image_shape.dim(2).size();
|
output_Cimage = output_image_shape.dim_size(2);
|
||||||
|
|
||||||
if (number_colors > 256) // Go to full color image
|
if (number_colors > 256) // Go to full color image
|
||||||
output_Cimage = 3;
|
output_Cimage = 3;
|
||||||
|
|
||||||
int data_Xwindow = output_data_window.dim(0).size();
|
int data_Xwindow = output_data_window.dim_size(0);
|
||||||
int data_Ywindow = output_data_window.dim(1).size();
|
int data_Ywindow = output_data_window.dim_size(1);
|
||||||
|
|
||||||
int deltaX_border_image = output_Ximage - data_Xwindow;
|
int deltaX_border_image = output_Ximage - data_Xwindow;
|
||||||
int deltaY_border_image = output_Yimage - data_Ywindow;
|
int deltaY_border_image = output_Yimage - data_Ywindow;
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/fingerprint.h"
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
@ -29,10 +29,8 @@ REGISTER_OP("InfeedDequeue")
|
|||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
PartialTensorShape shape;
|
PartialTensorShape shape;
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
|
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
|
||||||
TensorShapeProto shape_proto;
|
|
||||||
shape.AsProto(&shape_proto);
|
|
||||||
ShapeHandle out;
|
ShapeHandle out;
|
||||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
|
||||||
c->set_output(0, out);
|
c->set_output(0, out);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
@ -87,10 +85,8 @@ REGISTER_OP("InfeedDequeueTuple")
|
|||||||
std::vector<PartialTensorShape> shapes;
|
std::vector<PartialTensorShape> shapes;
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
|
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
|
||||||
for (int i = 0; i < shapes.size(); ++i) {
|
for (int i = 0; i < shapes.size(); ++i) {
|
||||||
TensorShapeProto shape_proto;
|
|
||||||
shapes[i].AsProto(&shape_proto);
|
|
||||||
ShapeHandle out;
|
ShapeHandle out;
|
||||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
|
||||||
c->set_output(i, out);
|
c->set_output(i, out);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -15,11 +15,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
||||||
|
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/kernels/immutable_constant_op.h"
|
#include "tensorflow/core/kernels/immutable_constant_op.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
@ -500,7 +500,6 @@ cc_library(
|
|||||||
# Generates library per group of ops.
|
# Generates library per group of ops.
|
||||||
tf_gen_op_libs(
|
tf_gen_op_libs(
|
||||||
op_lib_names = [
|
op_lib_names = [
|
||||||
"array_ops",
|
|
||||||
"bitwise_ops",
|
"bitwise_ops",
|
||||||
"candidate_sampling_ops",
|
"candidate_sampling_ops",
|
||||||
"control_flow_ops",
|
"control_flow_ops",
|
||||||
@ -534,6 +533,13 @@ tf_gen_op_libs(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_gen_op_libs(
|
||||||
|
op_lib_names = [
|
||||||
|
"array_ops",
|
||||||
|
],
|
||||||
|
deps = [":protos_all_cc"],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_libs(
|
tf_gen_op_libs(
|
||||||
op_lib_names = [
|
op_lib_names = [
|
||||||
"audio_ops",
|
"audio_ops",
|
||||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/gradients.h"
|
#include "tensorflow/core/graph/gradients.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/gpu/process_state.h"
|
#include "tensorflow/core/common_runtime/gpu/process_state.h"
|
||||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_reference.h"
|
#include "tensorflow/core/framework/tensor_reference.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_util.h"
|
#include "tensorflow/core/framework/tensor_util.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/op_segment.h"
|
#include "tensorflow/core/framework/op_segment.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||||
@ -39,6 +40,10 @@ ShapeRefiner::ShapeRefiner(int graph_def_version,
|
|||||||
ops_registry_(ops),
|
ops_registry_(ops),
|
||||||
graph_runner_(Env::Default()) {}
|
graph_runner_(Env::Default()) {}
|
||||||
|
|
||||||
|
ShapeRefiner::ShapeRefiner(const VersionDef& versions,
|
||||||
|
const OpRegistryInterface* ops)
|
||||||
|
: ShapeRefiner(versions.producer(), ops) {}
|
||||||
|
|
||||||
ShapeRefiner::~ShapeRefiner() {
|
ShapeRefiner::~ShapeRefiner() {
|
||||||
// The lifetime of the tensors are bound to the GraphRunner, so the tensors
|
// The lifetime of the tensors are bound to the GraphRunner, so the tensors
|
||||||
// should be deleted before it.
|
// should be deleted before it.
|
||||||
|
@ -36,6 +36,10 @@ class GraphProperties;
|
|||||||
class ShapeRefiner {
|
class ShapeRefiner {
|
||||||
public:
|
public:
|
||||||
ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
|
ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
|
||||||
|
|
||||||
|
// Same as ShapeRefiner(versions.producer(), ops)
|
||||||
|
ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
|
||||||
|
|
||||||
~ShapeRefiner();
|
~ShapeRefiner();
|
||||||
|
|
||||||
// Performs validation of 'node' and runs 'node's shape function,
|
// Performs validation of 'node' and runs 'node's shape function,
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||||
#include "tensorflow/core/framework/graph_def_util.h"
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/subgraph.h"
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||||
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||||
#include "tensorflow/core/graph/costmodel.h"
|
#include "tensorflow/core/graph/costmodel.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/lib/strings/scanner.h"
|
#include "tensorflow/core/lib/strings/scanner.h"
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/summary.pb.h"
|
#include "tensorflow/core/framework/summary.pb.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
@ -153,6 +153,7 @@ cc_library(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -205,6 +206,7 @@ cc_test(
|
|||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensor_testutil",
|
"//tensorflow/core:tensor_testutil",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/rendezvous.h"
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/graph_partition.h"
|
#include "tensorflow/core/graph/graph_partition.h"
|
||||||
|
@ -26,11 +26,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/scheduler.h"
|
#include "tensorflow/core/distributed_runtime/scheduler.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||||
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_partition.h"
|
#include "tensorflow/core/graph/graph_partition.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
|
@ -108,6 +108,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++_unsecure",
|
||||||
],
|
],
|
||||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
|||||||
#include "grpc++/support/slice.h"
|
#include "grpc++/support/slice.h"
|
||||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_reference.h"
|
#include "tensorflow/core/framework/tensor_reference.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/io/proto_encode_helper.h"
|
#include "tensorflow/core/lib/io/proto_encode_helper.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "google/protobuf/any.pb.h"
|
#include "google/protobuf/any.pb.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
|
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/example/feature.pb_text.h"
|
#include "tensorflow/core/example/feature.pb_text.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/numeric_op.h"
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
@ -48,6 +48,14 @@ constexpr size_t Allocator::kAllocatorAlignment;
|
|||||||
|
|
||||||
Allocator::~Allocator() {}
|
Allocator::~Allocator() {}
|
||||||
|
|
||||||
|
void RunResourceCtor(ResourceHandle* p, size_t n) {
|
||||||
|
for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunResourceDtor(ResourceHandle* p, size_t n) {
|
||||||
|
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
|
||||||
|
}
|
||||||
|
|
||||||
// If true, cpu allocator collects more stats.
|
// If true, cpu allocator collects more stats.
|
||||||
static bool cpu_allocator_collect_stats = false;
|
static bool cpu_allocator_collect_stats = false;
|
||||||
// If true, cpu allocator collects full stats.
|
// If true, cpu allocator collects full stats.
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb_text.h"
|
#include "tensorflow/core/framework/types.pb_text.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -287,6 +288,8 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
|
|||||||
return ProtoParseFromString(to_parse, out);
|
return ProtoParseFromString(to_parse, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
|
||||||
|
|
||||||
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
|
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
|
||||||
void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
|
void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h" // TODO(62899350): Remove
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
@ -29,6 +29,10 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forward declare protos so their symbols can be removed from .so exports
|
||||||
|
class AttrValue;
|
||||||
|
class NameAttrList;
|
||||||
|
|
||||||
// A human-readable rendering of attr_value, that is more concise than a
|
// A human-readable rendering of attr_value, that is more concise than a
|
||||||
// text-format proto.
|
// text-format proto.
|
||||||
string SummarizeAttrValue(const AttrValue& attr_value);
|
string SummarizeAttrValue(const AttrValue& attr_value);
|
||||||
@ -80,9 +84,7 @@ void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
|
|||||||
void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
|
||||||
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
|
||||||
|
|
||||||
inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
|
void SetAttrValue(const AttrValue& value, AttrValue* out);
|
||||||
*out = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if a and b have the same value.
|
// Returns true if a and b have the same value.
|
||||||
// NOTE: May return false negatives for tensor values.
|
// NOTE: May return false negatives for tensor values.
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
@ -26,19 +26,11 @@ namespace shape_inference {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TensorShapeProto S(std::initializer_list<int64> dims) {
|
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||||
PartialTensorShape shape(dims);
|
return PartialTensorShape(dims);
|
||||||
TensorShapeProto ret;
|
|
||||||
shape.AsProto(&ret);
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorShapeProto Unknown() {
|
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||||
PartialTensorShape shape;
|
|
||||||
TensorShapeProto ret;
|
|
||||||
shape.AsProto(&ret);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||||
OpRegistrationData op_reg_data;
|
OpRegistrationData op_reg_data;
|
||||||
|
@ -19,4 +19,8 @@ namespace tensorflow {
|
|||||||
|
|
||||||
DeviceBase::~DeviceBase() {}
|
DeviceBase::~DeviceBase() {}
|
||||||
|
|
||||||
|
const DeviceAttributes& DeviceBase::attributes() const {
|
||||||
|
LOG(FATAL) << "Device does not implement attributes()";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
#include "tensorflow/core/framework/device_attributes.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -44,10 +44,12 @@ class Stream;
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
class DeviceAttributes;
|
||||||
class Env;
|
class Env;
|
||||||
class EventMgr;
|
class EventMgr;
|
||||||
class OpKernelContext;
|
class OpKernelContext;
|
||||||
class ResourceMgr;
|
class ResourceMgr;
|
||||||
|
class TensorProto;
|
||||||
|
|
||||||
namespace thread {
|
namespace thread {
|
||||||
class ThreadPool;
|
class ThreadPool;
|
||||||
@ -194,11 +196,8 @@ class DeviceBase {
|
|||||||
DeviceContext* /*dc*/,
|
DeviceContext* /*dc*/,
|
||||||
Allocator* /*allocator*/) {}
|
Allocator* /*allocator*/) {}
|
||||||
|
|
||||||
virtual const DeviceAttributes& attributes() const {
|
// Unimplemented by default
|
||||||
LOG(FATAL) << "Device does not implement attributes()";
|
virtual const DeviceAttributes& attributes() const;
|
||||||
static DeviceAttributes dummy;
|
|
||||||
return dummy;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Materializes the given TensorProto into 'tensor' stored in Device
|
// Materializes the given TensorProto into 'tensor' stored in Device
|
||||||
// memory. Most devices will want to override this.
|
// memory. Most devices will want to override this.
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/fake_input.h"
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/function.pb_text.h"
|
#include "tensorflow/core/framework/function.pb_text.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
@ -17,9 +17,10 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/selective_registration.h"
|
#include "tensorflow/core/framework/selective_registration.h"
|
||||||
@ -33,6 +34,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class CancellationManager;
|
class CancellationManager;
|
||||||
|
class GraphDef;
|
||||||
class OpKernel;
|
class OpKernel;
|
||||||
class ResourceMgr;
|
class ResourceMgr;
|
||||||
class ScopedStepContainer;
|
class ScopedStepContainer;
|
||||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
|
@ -17,13 +17,15 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forward declare proto so that it's symbols can be removed from .so exports
|
||||||
|
class GraphDef;
|
||||||
|
|
||||||
// Produce a human-readable version of a GraphDef that is more concise
|
// Produce a human-readable version of a GraphDef that is more concise
|
||||||
// than a text-format proto.
|
// than a text-format proto.
|
||||||
string SummarizeGraphDef(const GraphDef& graph_def);
|
string SummarizeGraphDef(const GraphDef& graph_def);
|
||||||
|
@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -24,6 +25,10 @@ KernelDefBuilder::KernelDefBuilder(const char* op_name) {
|
|||||||
kernel_def_->set_op(op_name);
|
kernel_def_->set_op(op_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KernelDefBuilder::~KernelDefBuilder() {
|
||||||
|
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
|
||||||
|
}
|
||||||
|
|
||||||
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
|
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
|
||||||
kernel_def_->set_device_type(device_type);
|
kernel_def_->set_device_type(device_type);
|
||||||
return *this;
|
return *this;
|
||||||
@ -61,4 +66,10 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const KernelDef* KernelDefBuilder::Build() {
|
||||||
|
KernelDef* r = kernel_def_;
|
||||||
|
kernel_def_ = nullptr;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
||||||
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -24,16 +24,16 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forward declare proto so that kernels don't need to depend on it
|
||||||
|
class KernelDef;
|
||||||
|
|
||||||
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
|
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
|
||||||
class KernelDefBuilder {
|
class KernelDefBuilder {
|
||||||
public:
|
public:
|
||||||
// Starts with just the name field set.
|
// Starts with just the name field set.
|
||||||
// Caller MUST call Build() and take ownership of the result.
|
// Caller MUST call Build() and take ownership of the result.
|
||||||
explicit KernelDefBuilder(const char* op_name);
|
explicit KernelDefBuilder(const char* op_name);
|
||||||
|
~KernelDefBuilder();
|
||||||
~KernelDefBuilder() {
|
|
||||||
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Required: specify the type of device this kernel supports.
|
// Required: specify the type of device this kernel supports.
|
||||||
// Returns *this.
|
// Returns *this.
|
||||||
@ -68,11 +68,7 @@ class KernelDefBuilder {
|
|||||||
// Returns a pointer to a KernelDef with fields set based on the
|
// Returns a pointer to a KernelDef with fields set based on the
|
||||||
// above calls to this instance.
|
// above calls to this instance.
|
||||||
// Caller takes ownership of the result.
|
// Caller takes ownership of the result.
|
||||||
const KernelDef* Build() {
|
const KernelDef* Build();
|
||||||
KernelDef* r = kernel_def_;
|
|
||||||
kernel_def_ = nullptr;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
KernelDef* kernel_def_;
|
KernelDef* kernel_def_;
|
||||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
||||||
#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class NodeDef;
|
||||||
|
|
||||||
// Returns into *{input,output}_memory_types the memory type of each
|
// Returns into *{input,output}_memory_types the memory type of each
|
||||||
// {input,output} tensor.
|
// {input,output} tensor.
|
||||||
//
|
//
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
@ -83,6 +84,25 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
|
||||||
|
DataType dt) {
|
||||||
|
const OpDef::ArgDef* arg = NextArgDef();
|
||||||
|
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
|
||||||
|
Input(src.node, src.index, src.data_type);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For inputs that take a list of tensors.
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
|
||||||
|
const OpDef::ArgDef* arg = NextArgDef();
|
||||||
|
if (arg != nullptr) ListInput(arg, src_list);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
|
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
|
||||||
StringPiece src_node, int src_index,
|
StringPiece src_node, int src_index,
|
||||||
DataType dt) {
|
DataType dt) {
|
||||||
@ -228,14 +248,51 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
|
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
||||||
const AttrValue& found,
|
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
|
||||||
const AttrValue& attr_value) {
|
if (!AreAttrValuesEqual(*found, value)) {
|
||||||
if (!AreAttrValuesEqual(found, attr_value)) {
|
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
|
||||||
errors_.push_back(strings::StrCat(
|
"' ", SummarizeAttrValue(*found),
|
||||||
"Inconsistent values for attr '", attr_name, "' ",
|
" vs. ", SummarizeAttrValue(value)));
|
||||||
SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
|
}
|
||||||
|
} else {
|
||||||
|
AddNodeAttr(name, value, &node_def_);
|
||||||
}
|
}
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define ATTR(T) \
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
|
||||||
|
AttrValue attr_value; \
|
||||||
|
SetAttrValue(value, &attr_value); \
|
||||||
|
return Attr(name, attr_value); \
|
||||||
|
}
|
||||||
|
ATTR(StringPiece)
|
||||||
|
ATTR(const char*)
|
||||||
|
ATTR(int32)
|
||||||
|
ATTR(int64)
|
||||||
|
ATTR(float)
|
||||||
|
ATTR(double)
|
||||||
|
ATTR(bool)
|
||||||
|
ATTR(DataType)
|
||||||
|
ATTR(const PartialTensorShape&)
|
||||||
|
ATTR(const Tensor&)
|
||||||
|
ATTR(const TensorProto&)
|
||||||
|
ATTR(const NameAttrList&)
|
||||||
|
ATTR(gtl::ArraySlice<StringPiece>)
|
||||||
|
ATTR(gtl::ArraySlice<const char*>)
|
||||||
|
ATTR(gtl::ArraySlice<string>)
|
||||||
|
ATTR(gtl::ArraySlice<int32>)
|
||||||
|
ATTR(gtl::ArraySlice<int64>)
|
||||||
|
ATTR(gtl::ArraySlice<float>)
|
||||||
|
ATTR(gtl::ArraySlice<bool>)
|
||||||
|
ATTR(const std::vector<bool>&)
|
||||||
|
ATTR(gtl::ArraySlice<DataType>)
|
||||||
|
ATTR(gtl::ArraySlice<TensorShape>)
|
||||||
|
ATTR(gtl::ArraySlice<PartialTensorShape>)
|
||||||
|
ATTR(gtl::ArraySlice<TensorShapeProto>)
|
||||||
|
ATTR(gtl::ArraySlice<Tensor>)
|
||||||
|
ATTR(gtl::ArraySlice<NameAttrList>)
|
||||||
|
#undef ATTR
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
@ -72,22 +72,11 @@ class NodeDefBuilder {
|
|||||||
// *and in the same order as the input_args appear in the OpDef.*
|
// *and in the same order as the input_args appear in the OpDef.*
|
||||||
|
|
||||||
// For inputs that take a single tensor.
|
// For inputs that take a single tensor.
|
||||||
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) {
|
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
|
||||||
const OpDef::ArgDef* arg = NextArgDef();
|
NodeDefBuilder& Input(const NodeOut& src);
|
||||||
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
NodeDefBuilder& Input(const NodeOut& src) {
|
|
||||||
Input(src.node, src.index, src.data_type);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For inputs that take a list of tensors.
|
// For inputs that take a list of tensors.
|
||||||
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) {
|
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
|
||||||
const OpDef::ArgDef* arg = NextArgDef();
|
|
||||||
if (arg != nullptr) ListInput(arg, src_list);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// To create inputs in tests, see fake_input.h.
|
// To create inputs in tests, see fake_input.h.
|
||||||
NodeDefBuilder& Input(FakeInputFunctor fake_input);
|
NodeDefBuilder& Input(FakeInputFunctor fake_input);
|
||||||
@ -100,13 +89,39 @@ class NodeDefBuilder {
|
|||||||
|
|
||||||
// Sets the attr, if not already set. If already set with a different
|
// Sets the attr, if not already set. If already set with a different
|
||||||
// value, an error will be returned from Finalize().
|
// value, an error will be returned from Finalize().
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const char* value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, int32 value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, int64 value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, float value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, double value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, bool value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, DataType value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name,
|
||||||
|
gtl::ArraySlice<PartialTensorShape> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name,
|
||||||
|
gtl::ArraySlice<TensorShapeProto> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
NodeDefBuilder& Attr(StringPiece attr_name, T&& value);
|
NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
|
||||||
// Note: overload needed to allow {...} expressions for value.
|
return Attr(name, gtl::ArraySlice<T>(value));
|
||||||
template <class T>
|
|
||||||
NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) {
|
|
||||||
Attr<std::initializer_list<T>>(attr_name, std::move(value));
|
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finish building the NodeDef, returning any errors or setting
|
// Finish building the NodeDef, returning any errors or setting
|
||||||
@ -152,9 +167,6 @@ class NodeDefBuilder {
|
|||||||
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckInconsistency(StringPiece attr_name, const AttrValue& found,
|
|
||||||
const AttrValue& attr_value);
|
|
||||||
|
|
||||||
const OpDef* op_def_;
|
const OpDef* op_def_;
|
||||||
NodeDef node_def_;
|
NodeDef node_def_;
|
||||||
int inputs_specified_;
|
int inputs_specified_;
|
||||||
@ -162,21 +174,6 @@ class NodeDefBuilder {
|
|||||||
std::vector<string> errors_;
|
std::vector<string> errors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// IMPLEMENTATION -------------------------------------------------------------
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) {
|
|
||||||
const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
|
|
||||||
if (found == nullptr) {
|
|
||||||
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
|
|
||||||
} else {
|
|
||||||
AttrValue attr_value;
|
|
||||||
SetAttrValue(std::forward<T>(value), &attr_value);
|
|
||||||
CheckInconsistency(attr_name, *found, attr_value);
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
|
#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -30,8 +30,9 @@ namespace tensorflow {
|
|||||||
|
|
||||||
class Node;
|
class Node;
|
||||||
|
|
||||||
// We forward declare NodeDef so that kernels don't need to depend on protos
|
// We forward declare protos so that kernels don't need to depend on them
|
||||||
class NodeDef;
|
class NodeDef;
|
||||||
|
class OpDef;
|
||||||
|
|
||||||
// Name of the attribute used to encode node colocation constraints.
|
// Name of the attribute used to encode node colocation constraints.
|
||||||
//
|
//
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/op_def_builder.h"
|
#include "tensorflow/core/framework/op_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
#include "tensorflow/core/framework/selective_registration.h"
|
#include "tensorflow/core/framework/selective_registration.h"
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/op_def_util.h"
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
@ -71,6 +73,9 @@ bool ConsumeEquals(StringPiece* description) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpGenOverrideMap::OpGenOverrideMap() {}
|
||||||
|
OpGenOverrideMap::~OpGenOverrideMap() {}
|
||||||
|
|
||||||
Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
|
Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
|
||||||
std::vector<string> v = str_util::Split(filenames, ",");
|
std::vector<string> v = str_util::Split(filenames, ",");
|
||||||
for (const string& f : v) {
|
for (const string& f : v) {
|
||||||
@ -86,7 +91,7 @@ Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) {
|
|||||||
OpGenOverrides all;
|
OpGenOverrides all;
|
||||||
protobuf::TextFormat::ParseFromString(contents, &all);
|
protobuf::TextFormat::ParseFromString(contents, &all);
|
||||||
for (const auto& one : all.op()) {
|
for (const auto& one : all.op()) {
|
||||||
map_[one.name()] = one;
|
map_[one.name()].reset(new OpGenOverride(one));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -142,7 +147,7 @@ const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
|
|||||||
// Look up
|
// Look up
|
||||||
const auto iter = map_.find(op_def->name());
|
const auto iter = map_.find(op_def->name());
|
||||||
if (iter == map_.end()) return nullptr;
|
if (iter == map_.end()) return nullptr;
|
||||||
const OpGenOverride& proto = iter->second;
|
const OpGenOverride& proto = *iter->second;
|
||||||
|
|
||||||
// Apply overrides from `proto`.
|
// Apply overrides from `proto`.
|
||||||
if (!proto.rename_to().empty()) {
|
if (!proto.rename_to().empty()) {
|
||||||
|
@ -18,14 +18,18 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
#include "tensorflow/core/framework/op_gen_overrides.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forward declare protos so their symbols can be removed from .so exports
|
||||||
|
class OpDef;
|
||||||
|
class OpGenOverride;
|
||||||
|
|
||||||
inline string Spaces(int n) { return string(n, ' '); }
|
inline string Spaces(int n) { return string(n, ' '); }
|
||||||
|
|
||||||
// Wrap prefix + str to be at most width characters, indenting every line
|
// Wrap prefix + str to be at most width characters, indenting every line
|
||||||
@ -43,6 +47,9 @@ bool ConsumeEquals(StringPiece* description);
|
|||||||
// look up the specific override for any given op.
|
// look up the specific override for any given op.
|
||||||
class OpGenOverrideMap {
|
class OpGenOverrideMap {
|
||||||
public:
|
public:
|
||||||
|
OpGenOverrideMap();
|
||||||
|
~OpGenOverrideMap();
|
||||||
|
|
||||||
// `filenames` is a comma-separated list of file names. If an op
|
// `filenames` is a comma-separated list of file names. If an op
|
||||||
// is mentioned in more than one file, the last one takes priority.
|
// is mentioned in more than one file, the last one takes priority.
|
||||||
Status LoadFileList(Env* env, const string& filenames);
|
Status LoadFileList(Env* env, const string& filenames);
|
||||||
@ -61,7 +68,7 @@ class OpGenOverrideMap {
|
|||||||
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
|
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<string, OpGenOverride> map_;
|
std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
|
@ -24,19 +24,19 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
#include "tensorflow/core/framework/control_flow.h"
|
#include "tensorflow/core/framework/control_flow.h"
|
||||||
#include "tensorflow/core/framework/device_base.h"
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/rendezvous.h"
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
#include "tensorflow/core/framework/selective_registration.h"
|
#include "tensorflow/core/framework/selective_registration.h"
|
||||||
#include "tensorflow/core/framework/session_state.h"
|
#include "tensorflow/core/framework/session_state.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/tracking_allocator.h"
|
#include "tensorflow/core/framework/tracking_allocator.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
@ -65,9 +65,13 @@ class TensorSliceReaderCacheWrapper;
|
|||||||
} // namespace checkpoint
|
} // namespace checkpoint
|
||||||
|
|
||||||
class AsyncOpKernel;
|
class AsyncOpKernel;
|
||||||
|
class FunctionCallFrame;
|
||||||
|
class FunctionLibraryRuntime;
|
||||||
class OpKernelConstruction; // declared below
|
class OpKernelConstruction; // declared below
|
||||||
class OpKernelContext; // declared below
|
class OpKernelContext; // declared below
|
||||||
|
class OpRegistryInterface;
|
||||||
class ResourceMgr;
|
class ResourceMgr;
|
||||||
|
class ScopedStepContainer;
|
||||||
|
|
||||||
class OpKernel {
|
class OpKernel {
|
||||||
public:
|
public:
|
||||||
|
@ -19,10 +19,12 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/fake_input.h"
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/reader_base.h"
|
#include "tensorflow/core/framework/reader_base.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/reader_base.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/coding.h"
|
#include "tensorflow/core/lib/core/coding.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
@ -19,12 +19,14 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "tensorflow/core/framework/queue_interface.h"
|
#include "tensorflow/core/framework/queue_interface.h"
|
||||||
#include "tensorflow/core/framework/reader_base.pb.h"
|
#include "tensorflow/core/framework/reader_base.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/reader_interface.h"
|
#include "tensorflow/core/framework/reader_interface.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class ReaderBaseState;
|
||||||
|
|
||||||
// Default implementation of ReaderInterface.
|
// Default implementation of ReaderInterface.
|
||||||
class ReaderBase : public ReaderInterface {
|
class ReaderBase : public ReaderInterface {
|
||||||
public:
|
public:
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
@ -22,8 +22,9 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/resource_handle.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/numbers.h"
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
@ -79,6 +80,58 @@ InferenceContext::InferenceContext(
|
|||||||
PostInputInit(std::move(handle_data));
|
PostInputInit(std::move(handle_data));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same as above, but with PartialTensorShape instead of TensorShapeProto
|
||||||
|
InferenceContext::InferenceContext(
|
||||||
|
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||||
|
const std::vector<PartialTensorShape>& input_shapes,
|
||||||
|
const std::vector<const Tensor*>& input_tensors,
|
||||||
|
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||||
|
const std::vector<
|
||||||
|
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
|
||||||
|
input_handle_shapes_and_types)
|
||||||
|
: graph_def_version_(graph_def_version),
|
||||||
|
node_def_(*CHECK_NOTNULL(node_def)) {
|
||||||
|
std::vector<ShapeHandle> input_tensors_as_shape_handles;
|
||||||
|
for (const PartialTensorShape& p : input_tensors_as_shapes) {
|
||||||
|
ShapeHandle shape;
|
||||||
|
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
|
||||||
|
if (!construction_status_.ok()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
input_tensors_as_shape_handles.push_back(shape);
|
||||||
|
}
|
||||||
|
PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
|
||||||
|
if (!construction_status_.ok()) return;
|
||||||
|
for (const PartialTensorShape& p : input_shapes) {
|
||||||
|
ShapeHandle shape;
|
||||||
|
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
|
||||||
|
if (!construction_status_.ok()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
inputs_.push_back(shape);
|
||||||
|
}
|
||||||
|
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
|
||||||
|
input_shapes.size());
|
||||||
|
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
|
||||||
|
const auto& v = input_handle_shapes_and_types[i];
|
||||||
|
if (v == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
|
||||||
|
auto& new_v = *handle_data[i];
|
||||||
|
for (int j = 0; j < v->size(); ++j) {
|
||||||
|
const auto& p = (*v)[j];
|
||||||
|
construction_status_.Update(
|
||||||
|
MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
|
||||||
|
if (!construction_status_.ok()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
new_v[j].dtype = p.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PostInputInit(std::move(handle_data));
|
||||||
|
}
|
||||||
|
|
||||||
InferenceContext::InferenceContext(
|
InferenceContext::InferenceContext(
|
||||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||||
const std::vector<ShapeHandle>& input_shapes,
|
const std::vector<ShapeHandle>& input_shapes,
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -189,6 +189,26 @@ class InferenceContext {
|
|||||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
|
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
|
||||||
input_handle_shapes_and_types);
|
input_handle_shapes_and_types);
|
||||||
|
|
||||||
|
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||||
|
//
|
||||||
|
// Elements of <input_tensors_as_shapes> are used for when a shape
|
||||||
|
// function makes a call to MakeShapeFromShapeTensor; in particular, when
|
||||||
|
// the input_tensors[i] is nullptr but the shape represented by it is
|
||||||
|
// partially known from analysis of the graph. <input_tensors_as_shapes>
|
||||||
|
// can have fewer elements than <input_shapes>. Values of
|
||||||
|
// <input_tensors_as_shapes> do not need to outlive the context.
|
||||||
|
//
|
||||||
|
// REQUIRES: <node_def> is not NULL, and must outlive the
|
||||||
|
// InferenceContext.
|
||||||
|
InferenceContext(
|
||||||
|
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||||
|
const std::vector<PartialTensorShape>& input_shapes,
|
||||||
|
const std::vector<const Tensor*>& input_tensors,
|
||||||
|
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||||
|
const std::vector<std::unique_ptr<
|
||||||
|
std::vector<std::pair<PartialTensorShape, DataType>>>>&
|
||||||
|
input_handle_shapes_and_types);
|
||||||
|
|
||||||
~InferenceContext();
|
~InferenceContext();
|
||||||
|
|
||||||
// Runs the shape inference function 'fn' with 'this' as the
|
// Runs the shape inference function 'fn' with 'this' as the
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/fake_input.h"
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_def_builder.h"
|
#include "tensorflow/core/framework/op_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -36,19 +37,11 @@ OpDef MakeOpDefWithLists() {
|
|||||||
return op_reg_data.op_def;
|
return op_reg_data.op_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorShapeProto S(std::initializer_list<int64> dims) {
|
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||||
PartialTensorShape shape(dims);
|
return PartialTensorShape(dims);
|
||||||
TensorShapeProto ret;
|
|
||||||
shape.AsProto(&ret);
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorShapeProto Unknown() {
|
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||||
PartialTensorShape shape;
|
|
||||||
TensorShapeProto ret;
|
|
||||||
shape.AsProto(&ret);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -1537,7 +1530,7 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
|
|||||||
{});
|
{});
|
||||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||||
ShapeHandle s;
|
ShapeHandle s;
|
||||||
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
|
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
|
||||||
return s;
|
return s;
|
||||||
};
|
};
|
||||||
auto get_shapes_and_types_from_context = [&](int idx) {
|
auto get_shapes_and_types_from_context = [&](int idx) {
|
||||||
@ -1648,7 +1641,7 @@ void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
|
|||||||
{});
|
{});
|
||||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||||
ShapeHandle s;
|
ShapeHandle s;
|
||||||
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
|
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
|
||||||
return s;
|
return s;
|
||||||
};
|
};
|
||||||
auto get_shapes_and_types_from_context = [&](int idx) {
|
auto get_shapes_and_types_from_context = [&](int idx) {
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
|
#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -28,7 +27,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class NodeDef;
|
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
|
||||||
struct ShapeInferenceTestOp {
|
struct ShapeInferenceTestOp {
|
||||||
|
@ -29,9 +29,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/log_memory.h"
|
#include "tensorflow/core/framework/log_memory.h"
|
||||||
#include "tensorflow/core/framework/resource_handle.pb.h"
|
#include "tensorflow/core/framework/resource_handle.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||||
#include "tensorflow/core/framework/type_traits.h"
|
#include "tensorflow/core/framework/type_traits.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/coding.h"
|
#include "tensorflow/core/lib/core/coding.h"
|
||||||
|
@ -17,10 +17,10 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
#include "tensorflow/core/framework/allocation_description.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
#include "tensorflow/core/framework/tensor_description.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
@ -35,8 +35,13 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class TensorBuffer; // Forward declaration.
|
// Forward declarations. In particular, we forward declare protos so that their
|
||||||
|
// symbols can be removed from .so exports.
|
||||||
|
class AllocationDescription;
|
||||||
|
class TensorBuffer;
|
||||||
class TensorCApi;
|
class TensorCApi;
|
||||||
|
class TensorDescription;
|
||||||
|
class TensorProto;
|
||||||
|
|
||||||
/// @ingroup core
|
/// @ingroup core
|
||||||
/// Represents an n-dimensional array of values.
|
/// Represents an n-dimensional array of values.
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
||||||
#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -35,6 +35,7 @@ namespace tensorflow {
|
|||||||
template <class Shape>
|
template <class Shape>
|
||||||
class TensorShapeIter;
|
class TensorShapeIter;
|
||||||
class TensorShape;
|
class TensorShape;
|
||||||
|
class TensorShapeProto;
|
||||||
class PartialTensorShape;
|
class PartialTensorShape;
|
||||||
// END_SKIP_DOXYGEN
|
// END_SKIP_DOXYGEN
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/framework/versions.h"
|
#include "tensorflow/core/framework/versions.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
||||||
#define TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
#define TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class VersionDef;
|
||||||
|
|
||||||
// Check whether data with the given versions is compatible with the given
|
// Check whether data with the given versions is compatible with the given
|
||||||
// consumer and min producer. upper_name and lower_name are used to form
|
// consumer and min producer. upper_name and lower_name are used to form
|
||||||
// error messages upon failure. Example usage:
|
// error messages upon failure. Example usage:
|
||||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/costmodel.h"
|
#include "tensorflow/core/graph/costmodel.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
@ -255,9 +257,11 @@ Status Node::input_node(int idx, const Node** const_n) const {
|
|||||||
// Graph
|
// Graph
|
||||||
|
|
||||||
Graph::Graph(const OpRegistryInterface* ops)
|
Graph::Graph(const OpRegistryInterface* ops)
|
||||||
: ops_(ops, FunctionDefLibrary()), arena_(8 << 10 /* 8kB */) {
|
: ops_(ops, FunctionDefLibrary()),
|
||||||
versions_.set_producer(TF_GRAPH_DEF_VERSION);
|
versions_(new VersionDef),
|
||||||
versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
arena_(8 << 10 /* 8kB */) {
|
||||||
|
versions_->set_producer(TF_GRAPH_DEF_VERSION);
|
||||||
|
versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
||||||
|
|
||||||
// Initialize the name interning table for assigned_device_name.
|
// Initialize the name interning table for assigned_device_name.
|
||||||
device_names_.push_back("");
|
device_names_.push_back("");
|
||||||
@ -301,6 +305,9 @@ Graph::~Graph() {
|
|||||||
// destroy them.
|
// destroy them.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const VersionDef& Graph::versions() const { return *versions_; }
|
||||||
|
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
|
||||||
|
|
||||||
Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
|
Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
|
||||||
const OpDef* op_def;
|
const OpDef* op_def;
|
||||||
status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
|
status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
|
||||||
|
@ -41,10 +41,10 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
|
||||||
#include "tensorflow/core/graph/edgeset.h"
|
#include "tensorflow/core/graph/edgeset.h"
|
||||||
#include "tensorflow/core/lib/core/arena.h"
|
#include "tensorflow/core/lib/core/arena.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
@ -59,7 +59,9 @@ namespace tensorflow {
|
|||||||
class Edge;
|
class Edge;
|
||||||
class EdgeSetTest;
|
class EdgeSetTest;
|
||||||
class Graph;
|
class Graph;
|
||||||
|
class GraphDef;
|
||||||
class Node;
|
class Node;
|
||||||
|
class VersionDef;
|
||||||
|
|
||||||
class NeighborIter; // Declared below
|
class NeighborIter; // Declared below
|
||||||
class NodeIter; // Declared below
|
class NodeIter; // Declared below
|
||||||
@ -370,8 +372,8 @@ class Graph {
|
|||||||
static const int kControlSlot;
|
static const int kControlSlot;
|
||||||
|
|
||||||
// The GraphDef version range of this graph (see graph.proto).
|
// The GraphDef version range of this graph (see graph.proto).
|
||||||
const VersionDef& versions() const { return versions_; }
|
const VersionDef& versions() const;
|
||||||
void set_versions(const VersionDef& versions) { versions_ = versions; }
|
void set_versions(const VersionDef& versions);
|
||||||
|
|
||||||
// Adds a new node to this graph, and returns it. Infers the Op and
|
// Adds a new node to this graph, and returns it. Infers the Op and
|
||||||
// input/output types for the node. *this owns the returned instance.
|
// input/output types for the node. *this owns the returned instance.
|
||||||
@ -514,7 +516,7 @@ class Graph {
|
|||||||
FunctionLibraryDefinition ops_;
|
FunctionLibraryDefinition ops_;
|
||||||
|
|
||||||
// GraphDef versions
|
// GraphDef versions
|
||||||
VersionDef versions_;
|
const std::unique_ptr<VersionDef> versions_;
|
||||||
|
|
||||||
// Allocator which will give us good locality.
|
// Allocator which will give us good locality.
|
||||||
core::Arena arena_;
|
core::Arena arena_;
|
||||||
|
@ -27,8 +27,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/versions.h"
|
#include "tensorflow/core/framework/versions.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
@ -22,7 +22,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/memory_types.h"
|
#include "tensorflow/core/framework/memory_types.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/control_flow.h"
|
#include "tensorflow/core/graph/control_flow.h"
|
||||||
#include "tensorflow/core/graph/costmodel.h"
|
#include "tensorflow/core/graph/costmodel.h"
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||||
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
@ -60,6 +60,7 @@ cc_test(
|
|||||||
"//tensorflow/cc:scope",
|
"//tensorflow/cc:scope",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensor_testutil",
|
"//tensorflow/core:tensor_testutil",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
@ -196,6 +197,7 @@ cc_test(
|
|||||||
":virtual_placer",
|
":virtual_placer",
|
||||||
":virtual_scheduler",
|
":virtual_scheduler",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
@ -232,6 +234,7 @@ cc_library(
|
|||||||
":cost_estimator",
|
":cost_estimator",
|
||||||
":op_performance_data_cc",
|
":op_performance_data_cc",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler/clusters:utils",
|
"//tensorflow/core/grappler/clusters:utils",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
@ -265,6 +268,7 @@ cc_library(
|
|||||||
":virtual_scheduler",
|
":virtual_scheduler",
|
||||||
"//tensorflow/core:core_cpu_base",
|
"//tensorflow/core:core_cpu_base",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/graph/types.h"
|
#include "tensorflow/core/graph/types.h"
|
||||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||||
|
@ -179,7 +179,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
|
|||||||
|
|
||||||
Status GraphProperties::InferStatically() {
|
Status GraphProperties::InferStatically() {
|
||||||
Graph graph(OpRegistry::Global());
|
Graph graph(OpRegistry::Global());
|
||||||
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
|
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
|
||||||
shape_refiner.set_require_shape_inference_fns(false);
|
shape_refiner.set_require_shape_inference_fns(false);
|
||||||
ImportGraphDefOptions options;
|
ImportGraphDefOptions options;
|
||||||
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
|
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user