Prepare to remove a bunch of proto.h includes from tensorflow/core headers
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. This CL does not remove any actual headers, but changes a bunch of files so that header removal is possible in a followup CL. It also marks the headers that will be removed with // TODO(b/62899350): Remove RELNOTES: n/a PiperOrigin-RevId: 160552878
This commit is contained in:
parent
9b11f45819
commit
e85d3df92d
@ -62,6 +62,7 @@ tf_cuda_library(
|
||||
"//tensorflow/cc:scope_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
|
@ -28,12 +28,15 @@ limitations under the License.
|
||||
#endif
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -1587,6 +1590,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
|
||||
|
||||
// TF_Graph functions ---------------------------------------------------------
|
||||
|
||||
TF_Graph::TF_Graph()
|
||||
: graph(tensorflow::OpRegistry::Global()),
|
||||
refiner(graph.versions().producer(), graph.op_registry()),
|
||||
num_sessions(0),
|
||||
delete_requested(false),
|
||||
parent(nullptr),
|
||||
parent_inputs(nullptr) {}
|
||||
|
||||
TF_Graph* TF_NewGraph() { return new TF_Graph; }
|
||||
|
||||
void TF_DeleteGraph(TF_Graph* g) {
|
||||
|
@ -56,13 +56,8 @@ struct TF_Library {
|
||||
};
|
||||
|
||||
struct TF_Graph {
|
||||
TF_Graph()
|
||||
: graph(tensorflow::OpRegistry::Global()),
|
||||
refiner(graph.versions().producer(), graph.op_registry()),
|
||||
num_sessions(0),
|
||||
delete_requested(false),
|
||||
parent(nullptr),
|
||||
parent_inputs(nullptr) {}
|
||||
TF_Graph();
|
||||
|
||||
tensorflow::mutex mu;
|
||||
tensorflow::Graph graph GUARDED_BY(mu);
|
||||
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
|
@ -45,6 +45,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
@ -432,6 +433,7 @@ cc_library_with_android_deps(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:op_gen_overrides_proto_cc",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
@ -18,8 +18,12 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb_text.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/testutil.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
||||
Scope Scope::NewRootScope() {
|
||||
Graph* graph = new Graph(OpRegistry::Global());
|
||||
ShapeRefiner* refiner =
|
||||
new ShapeRefiner(graph->versions().producer(), graph->op_registry());
|
||||
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
|
@ -139,6 +139,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
|
@ -31,9 +31,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
@ -54,8 +54,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
|
||||
float normalize_min;
|
||||
float border_level;
|
||||
int number_colors;
|
||||
::tensorflow::TensorShapeProto output_image_shape;
|
||||
::tensorflow::TensorShapeProto output_data_window;
|
||||
::tensorflow::PartialTensorShape output_image_shape;
|
||||
::tensorflow::PartialTensorShape output_data_window;
|
||||
|
||||
uint8 Cblack = 0;
|
||||
uint8 Cwhite = 255;
|
||||
@ -109,15 +109,15 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
|
||||
input_Yvalue =
|
||||
input_tensor.shape().dim_size(0); // Y value is the number of rows
|
||||
|
||||
output_Ximage = output_image_shape.dim(0).size();
|
||||
output_Yimage = output_image_shape.dim(1).size();
|
||||
output_Cimage = output_image_shape.dim(2).size();
|
||||
output_Ximage = output_image_shape.dim_size(0);
|
||||
output_Yimage = output_image_shape.dim_size(1);
|
||||
output_Cimage = output_image_shape.dim_size(2);
|
||||
|
||||
if (number_colors > 256) // Go to full color image
|
||||
output_Cimage = 3;
|
||||
|
||||
int data_Xwindow = output_data_window.dim(0).size();
|
||||
int data_Ywindow = output_data_window.dim(1).size();
|
||||
int data_Xwindow = output_data_window.dim_size(0);
|
||||
int data_Ywindow = output_data_window.dim_size(1);
|
||||
|
||||
int deltaX_border_image = output_Ximage - data_Xwindow;
|
||||
int deltaY_border_image = output_Yimage - data_Ywindow;
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
|
@ -29,10 +29,8 @@ REGISTER_OP("InfeedDequeue")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
|
||||
TensorShapeProto shape_proto;
|
||||
shape.AsProto(&shape_proto);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
@ -87,10 +85,8 @@ REGISTER_OP("InfeedDequeueTuple")
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
TensorShapeProto shape_proto;
|
||||
shapes[i].AsProto(&shape_proto);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
|
||||
c->set_output(i, out);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -15,11 +15,13 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/immutable_constant_op.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
@ -500,7 +500,6 @@ cc_library(
|
||||
# Generates library per group of ops.
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"array_ops",
|
||||
"bitwise_ops",
|
||||
"candidate_sampling_ops",
|
||||
"control_flow_ops",
|
||||
@ -534,6 +533,13 @@ tf_gen_op_libs(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"array_ops",
|
||||
],
|
||||
deps = [":protos_all_cc"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"audio_ops",
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/gradients.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/gpu/process_state.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/op_segment.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
@ -39,6 +40,10 @@ ShapeRefiner::ShapeRefiner(int graph_def_version,
|
||||
ops_registry_(ops),
|
||||
graph_runner_(Env::Default()) {}
|
||||
|
||||
ShapeRefiner::ShapeRefiner(const VersionDef& versions,
|
||||
const OpRegistryInterface* ops)
|
||||
: ShapeRefiner(versions.producer(), ops) {}
|
||||
|
||||
ShapeRefiner::~ShapeRefiner() {
|
||||
// The lifetime of the tensors are bound to the GraphRunner, so the tensors
|
||||
// should be deleted before it.
|
||||
|
@ -36,6 +36,10 @@ class GraphProperties;
|
||||
class ShapeRefiner {
|
||||
public:
|
||||
ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
|
||||
|
||||
// Same as ShapeRefiner(versions.producer(), ops)
|
||||
ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
|
||||
|
||||
~ShapeRefiner();
|
||||
|
||||
// Performs validation of 'node' and runs 'node's shape function,
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
|
@ -153,6 +153,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
@ -205,6 +206,7 @@ cc_test(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_partition.h"
|
||||
|
@ -26,11 +26,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/scheduler.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/graph/graph_partition.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
|
@ -108,6 +108,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"@grpc//:grpc++_unsecure",
|
||||
],
|
||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
||||
#include "grpc++/support/slice.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/io/proto_encode_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include "google/protobuf/any.pb.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
|
||||
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/example/feature.pb_text.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -48,6 +48,14 @@ constexpr size_t Allocator::kAllocatorAlignment;
|
||||
|
||||
Allocator::~Allocator() {}
|
||||
|
||||
void RunResourceCtor(ResourceHandle* p, size_t n) {
|
||||
for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
|
||||
}
|
||||
|
||||
void RunResourceDtor(ResourceHandle* p, size_t n) {
|
||||
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
|
||||
}
|
||||
|
||||
// If true, cpu allocator collects more stats.
|
||||
static bool cpu_allocator_collect_stats = false;
|
||||
// If true, cpu allocator collects full stats.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb_text.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -287,6 +288,8 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
|
||||
return ProtoParseFromString(to_parse, out);
|
||||
}
|
||||
|
||||
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
|
||||
|
||||
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
|
||||
void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
|
||||
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h" // TODO(62899350): Remove
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -29,6 +29,10 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forward declare protos so their symbols can be removed from .so exports
|
||||
class AttrValue;
|
||||
class NameAttrList;
|
||||
|
||||
// A human-readable rendering of attr_value, that is more concise than a
|
||||
// text-format proto.
|
||||
string SummarizeAttrValue(const AttrValue& attr_value);
|
||||
@ -80,9 +84,7 @@ void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
|
||||
|
||||
inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
|
||||
*out = value;
|
||||
}
|
||||
void SetAttrValue(const AttrValue& value, AttrValue* out);
|
||||
|
||||
// Returns true if a and b have the same value.
|
||||
// NOTE: May return false negatives for tensor values.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
@ -26,19 +26,11 @@ namespace shape_inference {
|
||||
|
||||
namespace {
|
||||
|
||||
TensorShapeProto S(std::initializer_list<int64> dims) {
|
||||
PartialTensorShape shape(dims);
|
||||
TensorShapeProto ret;
|
||||
shape.AsProto(&ret);
|
||||
return ret;
|
||||
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||
return PartialTensorShape(dims);
|
||||
}
|
||||
|
||||
TensorShapeProto Unknown() {
|
||||
PartialTensorShape shape;
|
||||
TensorShapeProto ret;
|
||||
shape.AsProto(&ret);
|
||||
return ret;
|
||||
}
|
||||
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||
|
||||
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
OpRegistrationData op_reg_data;
|
||||
|
@ -19,4 +19,8 @@ namespace tensorflow {
|
||||
|
||||
DeviceBase::~DeviceBase() {}
|
||||
|
||||
const DeviceAttributes& DeviceBase::attributes() const {
|
||||
LOG(FATAL) << "Device does not implement attributes()";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -44,10 +44,12 @@ class Stream;
|
||||
namespace tensorflow {
|
||||
|
||||
class Device;
|
||||
class DeviceAttributes;
|
||||
class Env;
|
||||
class EventMgr;
|
||||
class OpKernelContext;
|
||||
class ResourceMgr;
|
||||
class TensorProto;
|
||||
|
||||
namespace thread {
|
||||
class ThreadPool;
|
||||
@ -194,11 +196,8 @@ class DeviceBase {
|
||||
DeviceContext* /*dc*/,
|
||||
Allocator* /*allocator*/) {}
|
||||
|
||||
virtual const DeviceAttributes& attributes() const {
|
||||
LOG(FATAL) << "Device does not implement attributes()";
|
||||
static DeviceAttributes dummy;
|
||||
return dummy;
|
||||
}
|
||||
// Unimplemented by default
|
||||
virtual const DeviceAttributes& attributes() const;
|
||||
|
||||
// Materializes the given TensorProto into 'tensor' stored in Device
|
||||
// memory. Most devices will want to override this.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function.pb_text.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
@ -17,9 +17,10 @@ limitations under the License.
|
||||
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
@ -33,6 +34,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
class CancellationManager;
|
||||
class GraphDef;
|
||||
class OpKernel;
|
||||
class ResourceMgr;
|
||||
class ScopedStepContainer;
|
||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
|
@ -17,13 +17,15 @@ limitations under the License.
|
||||
#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forward declare proto so that it's symbols can be removed from .so exports
|
||||
class GraphDef;
|
||||
|
||||
// Produce a human-readable version of a GraphDef that is more concise
|
||||
// than a text-format proto.
|
||||
string SummarizeGraphDef(const GraphDef& graph_def);
|
||||
|
@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -24,6 +25,10 @@ KernelDefBuilder::KernelDefBuilder(const char* op_name) {
|
||||
kernel_def_->set_op(op_name);
|
||||
}
|
||||
|
||||
KernelDefBuilder::~KernelDefBuilder() {
|
||||
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
|
||||
}
|
||||
|
||||
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
|
||||
kernel_def_->set_device_type(device_type);
|
||||
return *this;
|
||||
@ -61,4 +66,10 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
const KernelDef* KernelDefBuilder::Build() {
|
||||
KernelDef* r = kernel_def_;
|
||||
kernel_def_ = nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
||||
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -24,16 +24,16 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forward declare proto so that kernels don't need to depend on it
|
||||
class KernelDef;
|
||||
|
||||
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
|
||||
class KernelDefBuilder {
|
||||
public:
|
||||
// Starts with just the name field set.
|
||||
// Caller MUST call Build() and take ownership of the result.
|
||||
explicit KernelDefBuilder(const char* op_name);
|
||||
|
||||
~KernelDefBuilder() {
|
||||
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
|
||||
}
|
||||
~KernelDefBuilder();
|
||||
|
||||
// Required: specify the type of device this kernel supports.
|
||||
// Returns *this.
|
||||
@ -68,11 +68,7 @@ class KernelDefBuilder {
|
||||
// Returns a pointer to a KernelDef with fields set based on the
|
||||
// above calls to this instance.
|
||||
// Caller takes ownership of the result.
|
||||
const KernelDef* Build() {
|
||||
KernelDef* r = kernel_def_;
|
||||
kernel_def_ = nullptr;
|
||||
return r;
|
||||
}
|
||||
const KernelDef* Build();
|
||||
|
||||
private:
|
||||
KernelDef* kernel_def_;
|
||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
||||
#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class NodeDef;
|
||||
|
||||
// Returns into *{input,output}_memory_types the memory type of each
|
||||
// {input,output} tensor.
|
||||
//
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -83,6 +84,25 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
|
||||
DataType dt) {
|
||||
const OpDef::ArgDef* arg = NextArgDef();
|
||||
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
|
||||
Input(src.node, src.index, src.data_type);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// For inputs that take a list of tensors.
|
||||
NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
|
||||
const OpDef::ArgDef* arg = NextArgDef();
|
||||
if (arg != nullptr) ListInput(arg, src_list);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
|
||||
StringPiece src_node, int src_index,
|
||||
DataType dt) {
|
||||
@ -228,14 +248,51 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
|
||||
}
|
||||
}
|
||||
|
||||
void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
|
||||
const AttrValue& found,
|
||||
const AttrValue& attr_value) {
|
||||
if (!AreAttrValuesEqual(found, attr_value)) {
|
||||
errors_.push_back(strings::StrCat(
|
||||
"Inconsistent values for attr '", attr_name, "' ",
|
||||
SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
||||
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
|
||||
if (!AreAttrValuesEqual(*found, value)) {
|
||||
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
|
||||
"' ", SummarizeAttrValue(*found),
|
||||
" vs. ", SummarizeAttrValue(value)));
|
||||
}
|
||||
} else {
|
||||
AddNodeAttr(name, value, &node_def_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
#define ATTR(T) \
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
|
||||
AttrValue attr_value; \
|
||||
SetAttrValue(value, &attr_value); \
|
||||
return Attr(name, attr_value); \
|
||||
}
|
||||
ATTR(StringPiece)
|
||||
ATTR(const char*)
|
||||
ATTR(int32)
|
||||
ATTR(int64)
|
||||
ATTR(float)
|
||||
ATTR(double)
|
||||
ATTR(bool)
|
||||
ATTR(DataType)
|
||||
ATTR(const PartialTensorShape&)
|
||||
ATTR(const Tensor&)
|
||||
ATTR(const TensorProto&)
|
||||
ATTR(const NameAttrList&)
|
||||
ATTR(gtl::ArraySlice<StringPiece>)
|
||||
ATTR(gtl::ArraySlice<const char*>)
|
||||
ATTR(gtl::ArraySlice<string>)
|
||||
ATTR(gtl::ArraySlice<int32>)
|
||||
ATTR(gtl::ArraySlice<int64>)
|
||||
ATTR(gtl::ArraySlice<float>)
|
||||
ATTR(gtl::ArraySlice<bool>)
|
||||
ATTR(const std::vector<bool>&)
|
||||
ATTR(gtl::ArraySlice<DataType>)
|
||||
ATTR(gtl::ArraySlice<TensorShape>)
|
||||
ATTR(gtl::ArraySlice<PartialTensorShape>)
|
||||
ATTR(gtl::ArraySlice<TensorShapeProto>)
|
||||
ATTR(gtl::ArraySlice<Tensor>)
|
||||
ATTR(gtl::ArraySlice<NameAttrList>)
|
||||
#undef ATTR
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -72,22 +72,11 @@ class NodeDefBuilder {
|
||||
// *and in the same order as the input_args appear in the OpDef.*
|
||||
|
||||
// For inputs that take a single tensor.
|
||||
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) {
|
||||
const OpDef::ArgDef* arg = NextArgDef();
|
||||
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
|
||||
return *this;
|
||||
}
|
||||
NodeDefBuilder& Input(const NodeOut& src) {
|
||||
Input(src.node, src.index, src.data_type);
|
||||
return *this;
|
||||
}
|
||||
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
|
||||
NodeDefBuilder& Input(const NodeOut& src);
|
||||
|
||||
// For inputs that take a list of tensors.
|
||||
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) {
|
||||
const OpDef::ArgDef* arg = NextArgDef();
|
||||
if (arg != nullptr) ListInput(arg, src_list);
|
||||
return *this;
|
||||
}
|
||||
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
|
||||
|
||||
// To create inputs in tests, see fake_input.h.
|
||||
NodeDefBuilder& Input(FakeInputFunctor fake_input);
|
||||
@ -100,13 +89,39 @@ class NodeDefBuilder {
|
||||
|
||||
// Sets the attr, if not already set. If already set with a different
|
||||
// value, an error will be returned from Finalize().
|
||||
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const char* value);
|
||||
NodeDefBuilder& Attr(StringPiece name, int32 value);
|
||||
NodeDefBuilder& Attr(StringPiece name, int64 value);
|
||||
NodeDefBuilder& Attr(StringPiece name, float value);
|
||||
NodeDefBuilder& Attr(StringPiece name, double value);
|
||||
NodeDefBuilder& Attr(StringPiece name, bool value);
|
||||
NodeDefBuilder& Attr(StringPiece name, DataType value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
|
||||
NodeDefBuilder& Attr(StringPiece name,
|
||||
gtl::ArraySlice<PartialTensorShape> value);
|
||||
NodeDefBuilder& Attr(StringPiece name,
|
||||
gtl::ArraySlice<TensorShapeProto> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
|
||||
|
||||
template <class T>
|
||||
NodeDefBuilder& Attr(StringPiece attr_name, T&& value);
|
||||
// Note: overload needed to allow {...} expressions for value.
|
||||
template <class T>
|
||||
NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) {
|
||||
Attr<std::initializer_list<T>>(attr_name, std::move(value));
|
||||
return *this;
|
||||
NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
|
||||
return Attr(name, gtl::ArraySlice<T>(value));
|
||||
}
|
||||
|
||||
// Finish building the NodeDef, returning any errors or setting
|
||||
@ -152,9 +167,6 @@ class NodeDefBuilder {
|
||||
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
||||
}
|
||||
|
||||
void CheckInconsistency(StringPiece attr_name, const AttrValue& found,
|
||||
const AttrValue& attr_value);
|
||||
|
||||
const OpDef* op_def_;
|
||||
NodeDef node_def_;
|
||||
int inputs_specified_;
|
||||
@ -162,21 +174,6 @@ class NodeDefBuilder {
|
||||
std::vector<string> errors_;
|
||||
};
|
||||
|
||||
// IMPLEMENTATION -------------------------------------------------------------
|
||||
|
||||
template <class T>
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) {
|
||||
const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
|
||||
if (found == nullptr) {
|
||||
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
|
||||
} else {
|
||||
AttrValue attr_value;
|
||||
SetAttrValue(std::forward<T>(value), &attr_value);
|
||||
CheckInconsistency(attr_name, *found, attr_value);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -30,8 +30,9 @@ namespace tensorflow {
|
||||
|
||||
class Node;
|
||||
|
||||
// We forward declare NodeDef so that kernels don't need to depend on protos
|
||||
// We forward declare protos so that kernels don't need to depend on them
|
||||
class NodeDef;
|
||||
class OpDef;
|
||||
|
||||
// Name of the attribute used to encode node colocation constraints.
|
||||
//
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -71,6 +73,9 @@ bool ConsumeEquals(StringPiece* description) {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpGenOverrideMap::OpGenOverrideMap() {}
|
||||
OpGenOverrideMap::~OpGenOverrideMap() {}
|
||||
|
||||
Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
|
||||
std::vector<string> v = str_util::Split(filenames, ",");
|
||||
for (const string& f : v) {
|
||||
@ -86,7 +91,7 @@ Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) {
|
||||
OpGenOverrides all;
|
||||
protobuf::TextFormat::ParseFromString(contents, &all);
|
||||
for (const auto& one : all.op()) {
|
||||
map_[one.name()] = one;
|
||||
map_[one.name()].reset(new OpGenOverride(one));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -142,7 +147,7 @@ const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
|
||||
// Look up
|
||||
const auto iter = map_.find(op_def->name());
|
||||
if (iter == map_.end()) return nullptr;
|
||||
const OpGenOverride& proto = iter->second;
|
||||
const OpGenOverride& proto = *iter->second;
|
||||
|
||||
// Apply overrides from `proto`.
|
||||
if (!proto.rename_to().empty()) {
|
||||
|
@ -18,14 +18,18 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forward declare protos so their symbols can be removed from .so exports
|
||||
class OpDef;
|
||||
class OpGenOverride;
|
||||
|
||||
inline string Spaces(int n) { return string(n, ' '); }
|
||||
|
||||
// Wrap prefix + str to be at most width characters, indenting every line
|
||||
@ -43,6 +47,9 @@ bool ConsumeEquals(StringPiece* description);
|
||||
// look up the specific override for any given op.
|
||||
class OpGenOverrideMap {
|
||||
public:
|
||||
OpGenOverrideMap();
|
||||
~OpGenOverrideMap();
|
||||
|
||||
// `filenames` is a comma-separated list of file names. If an op
|
||||
// is mentioned in more than one file, the last one takes priority.
|
||||
Status LoadFileList(Env* env, const string& filenames);
|
||||
@ -61,7 +68,7 @@ class OpGenOverrideMap {
|
||||
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<string, OpGenOverride> map_;
|
||||
std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
|
@ -24,19 +24,19 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/function.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/framework/session_state.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/tracking_allocator.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
@ -65,9 +65,13 @@ class TensorSliceReaderCacheWrapper;
|
||||
} // namespace checkpoint
|
||||
|
||||
class AsyncOpKernel;
|
||||
class FunctionCallFrame;
|
||||
class FunctionLibraryRuntime;
|
||||
class OpKernelConstruction; // declared below
|
||||
class OpKernelContext; // declared below
|
||||
class OpRegistryInterface;
|
||||
class ResourceMgr;
|
||||
class ScopedStepContainer;
|
||||
|
||||
class OpKernel {
|
||||
public:
|
||||
|
@ -19,10 +19,12 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/reader_base.h"
|
||||
|
||||
#include "tensorflow/core/framework/reader_base.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
@ -19,12 +19,14 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tensorflow/core/framework/queue_interface.h"
|
||||
#include "tensorflow/core/framework/reader_base.pb.h"
|
||||
#include "tensorflow/core/framework/reader_base.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/reader_interface.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ReaderBaseState;
|
||||
|
||||
// Default implementation of ReaderInterface.
|
||||
class ReaderBase : public ReaderInterface {
|
||||
public:
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
@ -22,8 +22,9 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
@ -79,6 +80,58 @@ InferenceContext::InferenceContext(
|
||||
PostInputInit(std::move(handle_data));
|
||||
}
|
||||
|
||||
// Same as above, but with PartialTensorShape instead of TensorShapeProto
|
||||
InferenceContext::InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||
const std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
|
||||
input_handle_shapes_and_types)
|
||||
: graph_def_version_(graph_def_version),
|
||||
node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
std::vector<ShapeHandle> input_tensors_as_shape_handles;
|
||||
for (const PartialTensorShape& p : input_tensors_as_shapes) {
|
||||
ShapeHandle shape;
|
||||
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
}
|
||||
input_tensors_as_shape_handles.push_back(shape);
|
||||
}
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
|
||||
if (!construction_status_.ok()) return;
|
||||
for (const PartialTensorShape& p : input_shapes) {
|
||||
ShapeHandle shape;
|
||||
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
}
|
||||
inputs_.push_back(shape);
|
||||
}
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
|
||||
input_shapes.size());
|
||||
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
|
||||
const auto& v = input_handle_shapes_and_types[i];
|
||||
if (v == nullptr) {
|
||||
continue;
|
||||
}
|
||||
handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
|
||||
auto& new_v = *handle_data[i];
|
||||
for (int j = 0; j < v->size(); ++j) {
|
||||
const auto& p = (*v)[j];
|
||||
construction_status_.Update(
|
||||
MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
}
|
||||
new_v[j].dtype = p.second;
|
||||
}
|
||||
}
|
||||
PostInputInit(std::move(handle_data));
|
||||
}
|
||||
|
||||
InferenceContext::InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -189,6 +189,26 @@ class InferenceContext {
|
||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
|
||||
input_handle_shapes_and_types);
|
||||
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
//
|
||||
// Elements of <input_tensors_as_shapes> are used for when a shape
|
||||
// function makes a call to MakeShapeFromShapeTensor; in particular, when
|
||||
// the input_tensors[i] is nullptr but the shape represented by it is
|
||||
// partially known from analysis of the graph. <input_tensors_as_shapes>
|
||||
// can have fewer elements than <input_shapes>. Values of
|
||||
// <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the
|
||||
// InferenceContext.
|
||||
InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||
const std::vector<std::unique_ptr<
|
||||
std::vector<std::pair<PartialTensorShape, DataType>>>>&
|
||||
input_handle_shapes_and_types);
|
||||
|
||||
~InferenceContext();
|
||||
|
||||
// Runs the shape inference function 'fn' with 'this' as the
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -36,19 +37,11 @@ OpDef MakeOpDefWithLists() {
|
||||
return op_reg_data.op_def;
|
||||
}
|
||||
|
||||
TensorShapeProto S(std::initializer_list<int64> dims) {
|
||||
PartialTensorShape shape(dims);
|
||||
TensorShapeProto ret;
|
||||
shape.AsProto(&ret);
|
||||
return ret;
|
||||
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||
return PartialTensorShape(dims);
|
||||
}
|
||||
|
||||
TensorShapeProto Unknown() {
|
||||
PartialTensorShape shape;
|
||||
TensorShapeProto ret;
|
||||
shape.AsProto(&ret);
|
||||
return ret;
|
||||
}
|
||||
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -1537,7 +1530,7 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
|
||||
{});
|
||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||
ShapeHandle s;
|
||||
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
|
||||
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
|
||||
return s;
|
||||
};
|
||||
auto get_shapes_and_types_from_context = [&](int idx) {
|
||||
@ -1648,7 +1641,7 @@ void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
|
||||
{});
|
||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||
ShapeHandle s;
|
||||
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
|
||||
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
|
||||
return s;
|
||||
};
|
||||
auto get_shapes_and_types_from_context = [&](int idx) {
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -28,7 +27,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class NodeDef;
|
||||
class Tensor;
|
||||
|
||||
struct ShapeInferenceTestOp {
|
||||
|
@ -29,9 +29,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/resource_handle.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
|
@ -17,10 +17,10 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -35,8 +35,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorBuffer; // Forward declaration.
|
||||
// Forward declarations. In particular, we forward declare protos so that their
|
||||
// symbols can be removed from .so exports.
|
||||
class AllocationDescription;
|
||||
class TensorBuffer;
|
||||
class TensorCApi;
|
||||
class TensorDescription;
|
||||
class TensorProto;
|
||||
|
||||
/// @ingroup core
|
||||
/// Represents an n-dimensional array of values.
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
||||
#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
|
||||
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -35,6 +35,7 @@ namespace tensorflow {
|
||||
template <class Shape>
|
||||
class TensorShapeIter;
|
||||
class TensorShape;
|
||||
class TensorShapeProto;
|
||||
class PartialTensorShape;
|
||||
// END_SKIP_DOXYGEN
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/versions.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
||||
#define TENSORFLOW_FRAMEWORK_VERSIONS_H_
|
||||
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class VersionDef;
|
||||
|
||||
// Check whether data with the given versions is compatible with the given
|
||||
// consumer and min producer. upper_name and lower_name are used to form
|
||||
// error messages upon failure. Example usage:
|
||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -255,9 +257,11 @@ Status Node::input_node(int idx, const Node** const_n) const {
|
||||
// Graph
|
||||
|
||||
Graph::Graph(const OpRegistryInterface* ops)
|
||||
: ops_(ops, FunctionDefLibrary()), arena_(8 << 10 /* 8kB */) {
|
||||
versions_.set_producer(TF_GRAPH_DEF_VERSION);
|
||||
versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
||||
: ops_(ops, FunctionDefLibrary()),
|
||||
versions_(new VersionDef),
|
||||
arena_(8 << 10 /* 8kB */) {
|
||||
versions_->set_producer(TF_GRAPH_DEF_VERSION);
|
||||
versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
||||
|
||||
// Initialize the name interning table for assigned_device_name.
|
||||
device_names_.push_back("");
|
||||
@ -301,6 +305,9 @@ Graph::~Graph() {
|
||||
// destroy them.
|
||||
}
|
||||
|
||||
const VersionDef& Graph::versions() const { return *versions_; }
|
||||
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
|
||||
|
||||
Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
|
||||
const OpDef* op_def;
|
||||
status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
|
||||
|
@ -41,10 +41,10 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/graph/edgeset.h"
|
||||
#include "tensorflow/core/lib/core/arena.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
@ -59,7 +59,9 @@ namespace tensorflow {
|
||||
class Edge;
|
||||
class EdgeSetTest;
|
||||
class Graph;
|
||||
class GraphDef;
|
||||
class Node;
|
||||
class VersionDef;
|
||||
|
||||
class NeighborIter; // Declared below
|
||||
class NodeIter; // Declared below
|
||||
@ -370,8 +372,8 @@ class Graph {
|
||||
static const int kControlSlot;
|
||||
|
||||
// The GraphDef version range of this graph (see graph.proto).
|
||||
const VersionDef& versions() const { return versions_; }
|
||||
void set_versions(const VersionDef& versions) { versions_ = versions; }
|
||||
const VersionDef& versions() const;
|
||||
void set_versions(const VersionDef& versions);
|
||||
|
||||
// Adds a new node to this graph, and returns it. Infers the Op and
|
||||
// input/output types for the node. *this owns the returned instance.
|
||||
@ -514,7 +516,7 @@ class Graph {
|
||||
FunctionLibraryDefinition ops_;
|
||||
|
||||
// GraphDef versions
|
||||
VersionDef versions_;
|
||||
const std::unique_ptr<VersionDef> versions_;
|
||||
|
||||
// Allocator which will give us good locality.
|
||||
core::Arena arena_;
|
||||
|
@ -27,8 +27,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -22,7 +22,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -60,6 +60,7 @@ cc_test(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
@ -196,6 +197,7 @@ cc_test(
|
||||
":virtual_placer",
|
||||
":virtual_scheduler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
@ -232,6 +234,7 @@ cc_library(
|
||||
":cost_estimator",
|
||||
":op_performance_data_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
@ -265,6 +268,7 @@ cc_library(
|
||||
":virtual_scheduler",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/graph/types.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
|
@ -179,7 +179,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
|
||||
|
||||
Status GraphProperties::InferStatically() {
|
||||
Graph graph(OpRegistry::Global());
|
||||
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
|
||||
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
|
||||
shape_refiner.set_require_shape_inference_fns(false);
|
||||
ImportGraphDefOptions options;
|
||||
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user