Prepare to remove a bunch of proto.h includes from tensorflow/core headers

The goal is to make kernels mostly independent of proto headers, which will let
us lock down our .so imports.  This CL does not remove any actual headers, but
changes a bunch of files so that header removal is possible in a followup CL.
It also marks the headers that will be removed with

    // TODO(b/62899350): Remove

RELNOTES: n/a
PiperOrigin-RevId: 160552878
This commit is contained in:
Geoffrey Irving 2017-06-29 11:44:13 -07:00 committed by TensorFlower Gardener
parent 9b11f45819
commit e85d3df92d
150 changed files with 555 additions and 240 deletions

View File

@ -62,6 +62,7 @@ tf_cuda_library(
"//tensorflow/cc:scope_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
],
}),

View File

@ -28,12 +28,15 @@ limitations under the License.
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
@ -1587,6 +1590,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
// TF_Graph functions ---------------------------------------------------------
TF_Graph::TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
TF_Graph* TF_NewGraph() { return new TF_Graph; }
void TF_DeleteGraph(TF_Graph* g) {

View File

@ -56,13 +56,8 @@ struct TF_Library {
};
struct TF_Graph {
TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
TF_Graph();
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/tensor_id.h"

View File

@ -45,6 +45,7 @@ tf_cc_test(
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
@ -432,6 +433,7 @@ cc_library_with_android_deps(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:op_gen_overrides_proto_cc",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],

View File

@ -18,8 +18,12 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions().producer(), graph->op_registry());
new ShapeRefiner(graph->versions(), graph->op_registry());
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
}

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"

View File

@ -139,6 +139,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:constant_op",

View File

@ -31,9 +31,11 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
namespace {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

View File

@ -54,8 +54,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
float normalize_min;
float border_level;
int number_colors;
::tensorflow::TensorShapeProto output_image_shape;
::tensorflow::TensorShapeProto output_data_window;
::tensorflow::PartialTensorShape output_image_shape;
::tensorflow::PartialTensorShape output_data_window;
uint8 Cblack = 0;
uint8 Cwhite = 255;
@ -109,15 +109,15 @@ class SingleImageRandomDotStereogramsOp : public OpKernel {
input_Yvalue =
input_tensor.shape().dim_size(0); // Y value is the number of rows
output_Ximage = output_image_shape.dim(0).size();
output_Yimage = output_image_shape.dim(1).size();
output_Cimage = output_image_shape.dim(2).size();
output_Ximage = output_image_shape.dim_size(0);
output_Yimage = output_image_shape.dim_size(1);
output_Cimage = output_image_shape.dim_size(2);
if (number_colors > 256) // Go to full color image
output_Cimage = 3;
int data_Xwindow = output_data_window.dim(0).size();
int data_Ywindow = output_data_window.dim(1).size();
int data_Xwindow = output_data_window.dim_size(0);
int data_Ywindow = output_data_window.dim_size(1);
int deltaX_border_image = output_Ximage - data_Xwindow;
int deltaY_border_image = output_Yimage - data_Ywindow;

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/work_sharder.h"

View File

@ -29,10 +29,8 @@ REGISTER_OP("InfeedDequeue")
.SetShapeFn([](InferenceContext* c) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
TensorShapeProto shape_proto;
shape.AsProto(&shape_proto);
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
c->set_output(0, out);
return Status::OK();
})
@ -87,10 +85,8 @@ REGISTER_OP("InfeedDequeueTuple")
std::vector<PartialTensorShape> shapes;
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
for (int i = 0; i < shapes.size(); ++i) {
TensorShapeProto shape_proto;
shapes[i].AsProto(&shape_proto);
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
c->set_output(i, out);
}
return Status::OK();

View File

@ -15,11 +15,13 @@ limitations under the License.
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/immutable_constant_op.h"
#include "tensorflow/core/platform/env.h"

View File

@ -500,7 +500,6 @@ cc_library(
# Generates library per group of ops.
tf_gen_op_libs(
op_lib_names = [
"array_ops",
"bitwise_ops",
"candidate_sampling_ops",
"control_flow_ops",
@ -534,6 +533,13 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
)
tf_gen_op_libs(
op_lib_names = [
"audio_ops",

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/notification.h"

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@ -39,6 +40,10 @@ ShapeRefiner::ShapeRefiner(int graph_def_version,
ops_registry_(ops),
graph_runner_(Env::Default()) {}
ShapeRefiner::ShapeRefiner(const VersionDef& versions,
const OpRegistryInterface* ops)
: ShapeRefiner(versions.producer(), ops) {}
ShapeRefiner::~ShapeRefiner() {
// The lifetime of the tensors are bound to the GraphRunner, so the tensors
// should be deleted before it.

View File

@ -36,6 +36,10 @@ class GraphProperties;
class ShapeRefiner {
public:
ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
// Same as ShapeRefiner(versions.producer(), ops)
ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
~ShapeRefiner();
// Performs validation of 'node' and runs 'node's shape function,

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/scanner.h"

View File

@ -27,6 +27,7 @@ limitations under the License.
#endif
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"

View File

@ -153,6 +153,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
],
)
@ -205,6 +206,7 @@ cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"

View File

@ -26,11 +26,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

View File

@ -108,6 +108,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++_unsecure",
],

View File

@ -18,7 +18,9 @@ limitations under the License.
#include "grpc++/support/slice.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/io/proto_encode_helper.h"
#include "tensorflow/core/platform/env.h"

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "google/protobuf/any.pb.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
namespace tensorflow {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"

View File

@ -17,9 +17,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/example/feature.pb_text.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -48,6 +48,14 @@ constexpr size_t Allocator::kAllocatorAlignment;
Allocator::~Allocator() {}
void RunResourceCtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
}
void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
@ -287,6 +288,8 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
return ProtoParseFromString(to_parse, out);
}
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h" // TODO(62899350): Remove
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -29,6 +29,10 @@ limitations under the License.
namespace tensorflow {
// Forward declare protos so their symbols can be removed from .so exports
class AttrValue;
class NameAttrList;
// A human-readable rendering of attr_value, that is more concise than a
// text-format proto.
string SummarizeAttrValue(const AttrValue& attr_value);
@ -80,9 +84,7 @@ void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
*out = value;
}
void SetAttrValue(const AttrValue& value, AttrValue* out);
// Returns true if a and b have the same value.
// NOTE: May return false negatives for tensor values.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@ -26,19 +26,11 @@ namespace shape_inference {
namespace {
TensorShapeProto S(std::initializer_list<int64> dims) {
PartialTensorShape shape(dims);
TensorShapeProto ret;
shape.AsProto(&ret);
return ret;
PartialTensorShape S(std::initializer_list<int64> dims) {
return PartialTensorShape(dims);
}
TensorShapeProto Unknown() {
PartialTensorShape shape;
TensorShapeProto ret;
shape.AsProto(&ret);
return ret;
}
PartialTensorShape Unknown() { return PartialTensorShape(); }
OpDef MakeOpDef(int num_inputs, int num_outputs) {
OpRegistrationData op_reg_data;

View File

@ -19,4 +19,8 @@ namespace tensorflow {
DeviceBase::~DeviceBase() {}
const DeviceAttributes& DeviceBase::attributes() const {
LOG(FATAL) << "Device does not implement attributes()";
}
} // namespace tensorflow

View File

@ -19,9 +19,9 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_attributes.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
@ -44,10 +44,12 @@ class Stream;
namespace tensorflow {
class Device;
class DeviceAttributes;
class Env;
class EventMgr;
class OpKernelContext;
class ResourceMgr;
class TensorProto;
namespace thread {
class ThreadPool;
@ -194,11 +196,8 @@ class DeviceBase {
DeviceContext* /*dc*/,
Allocator* /*allocator*/) {}
virtual const DeviceAttributes& attributes() const {
LOG(FATAL) << "Device does not implement attributes()";
static DeviceAttributes dummy;
return dummy;
}
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
// Materializes the given TensorProto into 'tensor' stored in Device
// memory. Most devices will want to override this.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/fake_input.h"
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"

View File

@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
@ -33,6 +34,7 @@ limitations under the License.
namespace tensorflow {
class CancellationManager;
class GraphDef;
class OpKernel;
class ResourceMgr;
class ScopedStepContainer;

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"

View File

@ -17,13 +17,15 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
#include <set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Forward declare proto so that it's symbols can be removed from .so exports
class GraphDef;
// Produce a human-readable version of a GraphDef that is more concise
// than a text-format proto.
string SummarizeGraphDef(const GraphDef& graph_def);

View File

@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
namespace tensorflow {
@ -24,6 +25,10 @@ KernelDefBuilder::KernelDefBuilder(const char* op_name) {
kernel_def_->set_op(op_name);
}
KernelDefBuilder::~KernelDefBuilder() {
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
}
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
kernel_def_->set_device_type(device_type);
return *this;
@ -61,4 +66,10 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
return *this;
}
const KernelDef* KernelDefBuilder::Build() {
KernelDef* r = kernel_def_;
kernel_def_ = nullptr;
return r;
}
} // namespace tensorflow

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
@ -24,16 +24,16 @@ limitations under the License.
namespace tensorflow {
// Forward declare proto so that kernels don't need to depend on it
class KernelDef;
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
class KernelDefBuilder {
public:
// Starts with just the name field set.
// Caller MUST call Build() and take ownership of the result.
explicit KernelDefBuilder(const char* op_name);
~KernelDefBuilder() {
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
}
~KernelDefBuilder();
// Required: specify the type of device this kernel supports.
// Returns *this.
@ -68,11 +68,7 @@ class KernelDefBuilder {
// Returns a pointer to a KernelDef with fields set based on the
// above calls to this instance.
// Caller takes ownership of the result.
const KernelDef* Build() {
KernelDef* r = kernel_def_;
kernel_def_ = nullptr;
return r;
}
const KernelDef* Build();
private:
KernelDef* kernel_def_;

View File

@ -16,12 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
class NodeDef;
// Returns into *{input,output}_memory_types the memory type of each
// {input,output} tensor.
//

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -83,6 +84,25 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
return *this;
}
NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
DataType dt) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
return *this;
}
NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
Input(src.node, src.index, src.data_type);
return *this;
}
// For inputs that take a list of tensors.
NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) ListInput(arg, src_list);
return *this;
}
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
StringPiece src_node, int src_index,
DataType dt) {
@ -228,14 +248,51 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
}
}
void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
const AttrValue& found,
const AttrValue& attr_value) {
if (!AreAttrValuesEqual(found, attr_value)) {
errors_.push_back(strings::StrCat(
"Inconsistent values for attr '", attr_name, "' ",
SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
if (!AreAttrValuesEqual(*found, value)) {
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
"' ", SummarizeAttrValue(*found),
" vs. ", SummarizeAttrValue(value)));
}
} else {
AddNodeAttr(name, value, &node_def_);
}
return *this;
}
#define ATTR(T) \
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
AttrValue attr_value; \
SetAttrValue(value, &attr_value); \
return Attr(name, attr_value); \
}
ATTR(StringPiece)
ATTR(const char*)
ATTR(int32)
ATTR(int64)
ATTR(float)
ATTR(double)
ATTR(bool)
ATTR(DataType)
ATTR(const PartialTensorShape&)
ATTR(const Tensor&)
ATTR(const TensorProto&)
ATTR(const NameAttrList&)
ATTR(gtl::ArraySlice<StringPiece>)
ATTR(gtl::ArraySlice<const char*>)
ATTR(gtl::ArraySlice<string>)
ATTR(gtl::ArraySlice<int32>)
ATTR(gtl::ArraySlice<int64>)
ATTR(gtl::ArraySlice<float>)
ATTR(gtl::ArraySlice<bool>)
ATTR(const std::vector<bool>&)
ATTR(gtl::ArraySlice<DataType>)
ATTR(gtl::ArraySlice<TensorShape>)
ATTR(gtl::ArraySlice<PartialTensorShape>)
ATTR(gtl::ArraySlice<TensorShapeProto>)
ATTR(gtl::ArraySlice<Tensor>)
ATTR(gtl::ArraySlice<NameAttrList>)
#undef ATTR
} // namespace tensorflow

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <functional>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@ -72,22 +72,11 @@ class NodeDefBuilder {
// *and in the same order as the input_args appear in the OpDef.*
// For inputs that take a single tensor.
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
return *this;
}
NodeDefBuilder& Input(const NodeOut& src) {
Input(src.node, src.index, src.data_type);
return *this;
}
NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
NodeDefBuilder& Input(const NodeOut& src);
// For inputs that take a list of tensors.
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) ListInput(arg, src_list);
return *this;
}
NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
// To create inputs in tests, see fake_input.h.
NodeDefBuilder& Input(FakeInputFunctor fake_input);
@ -100,13 +89,39 @@ class NodeDefBuilder {
// Sets the attr, if not already set. If already set with a different
// value, an error will be returned from Finalize().
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
NodeDefBuilder& Attr(StringPiece name, const char* value);
NodeDefBuilder& Attr(StringPiece name, int32 value);
NodeDefBuilder& Attr(StringPiece name, int64 value);
NodeDefBuilder& Attr(StringPiece name, float value);
NodeDefBuilder& Attr(StringPiece name, double value);
NodeDefBuilder& Attr(StringPiece name, bool value);
NodeDefBuilder& Attr(StringPiece name, DataType value);
NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
NodeDefBuilder& Attr(StringPiece name,
gtl::ArraySlice<PartialTensorShape> value);
NodeDefBuilder& Attr(StringPiece name,
gtl::ArraySlice<TensorShapeProto> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
template <class T>
NodeDefBuilder& Attr(StringPiece attr_name, T&& value);
// Note: overload needed to allow {...} expressions for value.
template <class T>
NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) {
Attr<std::initializer_list<T>>(attr_name, std::move(value));
return *this;
NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
return Attr(name, gtl::ArraySlice<T>(value));
}
// Finish building the NodeDef, returning any errors or setting
@ -152,9 +167,6 @@ class NodeDefBuilder {
return input_arg->is_ref() ? MakeRefType(dt) : dt;
}
void CheckInconsistency(StringPiece attr_name, const AttrValue& found,
const AttrValue& attr_value);
const OpDef* op_def_;
NodeDef node_def_;
int inputs_specified_;
@ -162,21 +174,6 @@ class NodeDefBuilder {
std::vector<string> errors_;
};
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) {
const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
if (found == nullptr) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
} else {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
CheckInconsistency(attr_name, *found, attr_value);
}
return *this;
}
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

View File

@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
@ -30,8 +30,9 @@ namespace tensorflow {
class Node;
// We forward declare NodeDef so that kernels don't need to depend on protos
// We forward declare protos so that kernels don't need to depend on them
class NodeDef;
class OpDef;
// Name of the attribute used to encode node colocation constraints.
//

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/selective_registration.h"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <set>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/types.h"

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -71,6 +73,9 @@ bool ConsumeEquals(StringPiece* description) {
return false;
}
OpGenOverrideMap::OpGenOverrideMap() {}
OpGenOverrideMap::~OpGenOverrideMap() {}
Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
std::vector<string> v = str_util::Split(filenames, ",");
for (const string& f : v) {
@ -86,7 +91,7 @@ Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) {
OpGenOverrides all;
protobuf::TextFormat::ParseFromString(contents, &all);
for (const auto& one : all.op()) {
map_[one.name()] = one;
map_[one.name()].reset(new OpGenOverride(one));
}
return Status::OK();
}
@ -142,7 +147,7 @@ const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
// Look up
const auto iter = map_.find(op_def->name());
if (iter == map_.end()) return nullptr;
const OpGenOverride& proto = iter->second;
const OpGenOverride& proto = *iter->second;
// Apply overrides from `proto`.
if (!proto.rename_to().empty()) {

View File

@ -18,14 +18,18 @@ limitations under the License.
#include <string>
#include <unordered_map>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op_gen_overrides.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
// Forward declare protos so their symbols can be removed from .so exports
class OpDef;
class OpGenOverride;
inline string Spaces(int n) { return string(n, ' '); }
// Wrap prefix + str to be at most width characters, indenting every line
@ -43,6 +47,9 @@ bool ConsumeEquals(StringPiece* description);
// look up the specific override for any given op.
class OpGenOverrideMap {
public:
OpGenOverrideMap();
~OpGenOverrideMap();
// `filenames` is a comma-separated list of file names. If an op
// is mentioned in more than one file, the last one takes priority.
Status LoadFileList(Env* env, const string& filenames);
@ -61,7 +68,7 @@ class OpGenOverrideMap {
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
private:
std::unordered_map<string, OpGenOverride> map_;
std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
};
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/log_memory.h"

View File

@ -24,19 +24,19 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/function.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
@ -65,9 +65,13 @@ class TensorSliceReaderCacheWrapper;
} // namespace checkpoint
class AsyncOpKernel;
class FunctionCallFrame;
class FunctionLibraryRuntime;
class OpKernelConstruction; // declared below
class OpKernelContext; // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class OpKernel {
public:

View File

@ -19,10 +19,12 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/reader_base.h"
#include "tensorflow/core/framework/reader_base.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -19,12 +19,14 @@ limitations under the License.
#include <memory>
#include <string>
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/reader_base.pb.h"
#include "tensorflow/core/framework/reader_base.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/reader_interface.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
class ReaderBaseState;
// Default implementation of ReaderInterface.
class ReaderBase : public ReaderInterface {
public:

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -22,8 +22,9 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
@ -79,6 +80,58 @@ InferenceContext::InferenceContext(
PostInputInit(std::move(handle_data));
}
// Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext::InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<PartialTensorShape>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
const std::vector<
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
node_def_(*CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
for (const PartialTensorShape& p : input_tensors_as_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
if (!construction_status_.ok()) {
return;
}
input_tensors_as_shape_handles.push_back(shape);
}
PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
if (!construction_status_.ok()) return;
for (const PartialTensorShape& p : input_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
if (!construction_status_.ok()) {
return;
}
inputs_.push_back(shape);
}
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
input_shapes.size());
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
const auto& v = input_handle_shapes_and_types[i];
if (v == nullptr) {
continue;
}
handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
auto& new_v = *handle_data[i];
for (int j = 0; j < v->size(); ++j) {
const auto& p = (*v)[j];
construction_status_.Update(
MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
if (!construction_status_.ok()) {
return;
}
new_v[j].dtype = p.second;
}
}
PostInputInit(std::move(handle_data));
}
InferenceContext::InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
@ -189,6 +189,26 @@ class InferenceContext {
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
input_handle_shapes_and_types);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used for when a shape
// function makes a call to MakeShapeFromShapeTensor; in particular, when
// the input_tensors[i] is nullptr but the shape represented by it is
// partially known from analysis of the graph. <input_tensors_as_shapes>
// can have fewer elements than <input_shapes>. Values of
// <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the
// InferenceContext.
InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<PartialTensorShape>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
const std::vector<std::unique_ptr<
std::vector<std::pair<PartialTensorShape, DataType>>>>&
input_handle_shapes_and_types);
~InferenceContext();
// Runs the shape inference function 'fn' with 'this' as the

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -36,19 +37,11 @@ OpDef MakeOpDefWithLists() {
return op_reg_data.op_def;
}
TensorShapeProto S(std::initializer_list<int64> dims) {
PartialTensorShape shape(dims);
TensorShapeProto ret;
shape.AsProto(&ret);
return ret;
PartialTensorShape S(std::initializer_list<int64> dims) {
return PartialTensorShape(dims);
}
TensorShapeProto Unknown() {
PartialTensorShape shape;
TensorShapeProto ret;
shape.AsProto(&ret);
return ret;
}
PartialTensorShape Unknown() { return PartialTensorShape(); }
} // namespace
@ -1537,7 +1530,7 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
{});
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
ShapeHandle s;
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
return s;
};
auto get_shapes_and_types_from_context = [&](int idx) {
@ -1648,7 +1641,7 @@ void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
{});
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
ShapeHandle s;
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
return s;
};
auto get_shapes_and_types_from_context = [&](int idx) {

View File

@ -16,7 +16,6 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
@ -28,7 +27,6 @@ limitations under the License.
namespace tensorflow {
class NodeDef;
class Tensor;
struct ShapeInferenceTestOp {

View File

@ -29,9 +29,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"

View File

@ -17,10 +17,10 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/allocation_description.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tensor_description.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
@ -35,8 +35,13 @@ limitations under the License.
namespace tensorflow {
class TensorBuffer; // Forward declaration.
// Forward declarations. In particular, we forward declare protos so that their
// symbols can be removed from .so exports.
class AllocationDescription;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
/// @ingroup core
/// Represents an n-dimensional array of values.

View File

@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@ -35,6 +35,7 @@ namespace tensorflow {
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/strcat.h"

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_
#define TENSORFLOW_FRAMEWORK_VERSIONS_H_
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class VersionDef;
// Check whether data with the given versions is compatible with the given
// consumer and min producer. upper_name and lower_name are used to form
// error messages upon failure. Example usage:

View File

@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/core/graph/costmodel.h"
#include <vector>
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -255,9 +257,11 @@ Status Node::input_node(int idx, const Node** const_n) const {
// Graph
Graph::Graph(const OpRegistryInterface* ops)
: ops_(ops, FunctionDefLibrary()), arena_(8 << 10 /* 8kB */) {
versions_.set_producer(TF_GRAPH_DEF_VERSION);
versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
: ops_(ops, FunctionDefLibrary()),
versions_(new VersionDef),
arena_(8 << 10 /* 8kB */) {
versions_->set_producer(TF_GRAPH_DEF_VERSION);
versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
// Initialize the name interning table for assigned_device_name.
device_names_.push_back("");
@ -301,6 +305,9 @@ Graph::~Graph() {
// destroy them.
}
const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
const OpDef* op_def;
status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));

View File

@ -41,10 +41,10 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
@ -59,7 +59,9 @@ namespace tensorflow {
class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
class VersionDef;
class NeighborIter; // Declared below
class NodeIter; // Declared below
@ -370,8 +372,8 @@ class Graph {
static const int kControlSlot;
// The GraphDef version range of this graph (see graph.proto).
const VersionDef& versions() const { return versions_; }
void set_versions(const VersionDef& versions) { versions_ = versions; }
const VersionDef& versions() const;
void set_versions(const VersionDef& versions);
// Adds a new node to this graph, and returns it. Infers the Op and
// input/output types for the node. *this owns the returned instance.
@ -514,7 +516,7 @@ class Graph {
FunctionLibraryDefinition ops_;
// GraphDef versions
VersionDef versions_;
const std::unique_ptr<VersionDef> versions_;
// Allocator which will give us good locality.
core::Arena arena_;

View File

@ -27,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_util.h"

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/costmodel.h"

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/platform/test.h"

View File

@ -60,6 +60,7 @@ cc_test(
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -196,6 +197,7 @@ cc_test(
":virtual_placer",
":virtual_scheduler",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -232,6 +234,7 @@ cc_library(
":cost_estimator",
":op_performance_data_cc",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/clusters:utils",
"//third_party/eigen3",
],
@ -265,6 +268,7 @@ cc_library(
":virtual_scheduler",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
],
)

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <unordered_map>
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"

View File

@ -179,7 +179,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
Status GraphProperties::InferStatically() {
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
ImportGraphDefOptions options;
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);

Some files were not shown because too many files have changed in this diff Show More