From a6ac9040dd2d447247031b8f92ce1ca62b159291 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev@google.com>
Date: Mon, 7 Oct 2019 14:49:17 -0700
Subject: [PATCH] Changed InferenceContext ctor to accept NodeDef as a const ref

It used to be a const ref, but cl/170078811 made it a non-const pointer to
allow resetting it in ShapeRefiner::InferShapesForFunction. That code path is
no longer used.

PiperOrigin-RevId: 273381405
---
 tensorflow/c/c_api_experimental.cc            |   2 +-
 tensorflow/c/kernels/bitcast_op_test.cc       |   6 +-
 tensorflow/c/ops_test.cc                      |  10 +-
 .../tensorflow/transforms/shape_inference.cc  |   2 +-
 .../common_runtime/eager/shape_inference.cc   |   2 +-
 .../core/common_runtime/shape_refiner.cc      |   2 +-
 .../core/framework/common_shape_fns_test.cc   |  81 +++++++------
 tensorflow/core/framework/shape_inference.cc  |  20 ++--
 tensorflow/core/framework/shape_inference.h   |  15 +--
 .../core/framework/shape_inference_test.cc    | 112 +++++++++---------
 .../framework/shape_inference_testutil.cc     |   2 +-
 .../core/grappler/costs/graph_properties.cc   |   2 +-
 tensorflow/core/ops/array_ops_test.cc         |   2 +-
 tensorflow/core/ops/math_ops_test.cc          |   2 +-
 14 files changed, 126 insertions(+), 134 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 12e714ea9e1..31a0f0e97f6 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -1111,7 +1111,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
   }

   // Create an inference context with dummy values, which will be updated later.
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
+  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                      std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                      {},
                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc
index 7e6dbe14725..7da27e99d1f 100644
--- a/tensorflow/c/kernels/bitcast_op_test.cc
+++ b/tensorflow/c/kernels/bitcast_op_test.cc
@@ -114,7 +114,7 @@ TEST(BitcastOpTest, TestShapeInference_LargerShape) {
                   .Attr("T", DT_INT64)
                   .Input(FakeInput(DT_INT64))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
   ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
@@ -132,7 +132,7 @@ TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
                   .Attr("T", DT_INT8)
                   .Input(FakeInput(DT_INT8))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4, 8})}, {}, {},
                                       {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
@@ -151,7 +151,7 @@ TEST(BitcastOpTest, TestShapeInference_SameShape) {
                   .Attr("T", DT_FLOAT)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
   ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
diff --git a/tensorflow/c/ops_test.cc b/tensorflow/c/ops_test.cc
index 0a6c5cd50fb..2e0a8e92b01 100644
--- a/tensorflow/c/ops_test.cc
+++ b/tensorflow/c/ops_test.cc
@@
-196,7 +196,7 @@ PartialTensorShape Unknown() { return PartialTensorShape(); } TEST(OpsTest, ShapeInferenceWithRank) { NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0), + shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0), {S({10, 20, 30})}, {}, {}, {}); shape_inference::ShapeHandle in0 = c.input(0); @@ -236,7 +236,7 @@ TEST(OpsTest, ShapeInferenceWithRank) { TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) { NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2), + shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, {}); shape_inference::ShapeHandle in0 = c.input(0); @@ -260,7 +260,7 @@ TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) { TEST(OpsTest, ShapeInferenceConcatenateShapes) { NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0), + shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0), {S({1, 2}), S({3, 4})}, {}, {}, {}); ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c))); shape_inference::ShapeHandle a = c.input(0); @@ -279,7 +279,7 @@ TEST(OpsTest, ShapeInferenceConcatenateShapes) { TEST(OpsTest, DimensionHandleValueKnown) { NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0), + shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0), {S({1, 2}), S({3, 4})}, {}, {}, {}); TF_ShapeHandle* handle = TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43); @@ -299,7 +299,7 @@ TEST(OpsTest, DimensionHandleValueKnown) { TEST(OpsTest, ShapeInferenceSubshape) { NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0), + shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0), {S({10, 20, 30, 40, 50})}, {}, {}, {}); ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0))); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 54fc64419bb..813460c5ab9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -115,7 +115,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // shapes. This object is abstracting the information that the ShapeInference // function operates on. 
tensorflow::shape_inference::InferenceContext c( - graph_version, node_def.get(), op_reg_data->op_def, input_shapes, + graph_version, *node_def, op_reg_data->op_def, input_shapes, /*input_tensors=*/{}, /*input_tensors_as_shapes=*/{}, /*input_handle_shapes_and_types=*/{}); auto status = c.Run(op_reg_data->shape_inference_fn); diff --git a/tensorflow/core/common_runtime/eager/shape_inference.cc b/tensorflow/core/common_runtime/eager/shape_inference.cc index 43ef02179a0..6df2d8c5465 100644 --- a/tensorflow/core/common_runtime/eager/shape_inference.cc +++ b/tensorflow/core/common_runtime/eager/shape_inference.cc @@ -37,7 +37,7 @@ Status RunShapeInference(const NodeDef& ndef, if (op_reg_data->shape_inference_fn == nullptr) return Status::OK(); shape_inference::InferenceContext ic( - TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def, + TF_GRAPH_DEF_VERSION, ndef, op_reg_data->op_def, std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {}); for (size_t i = 0; i < inputs.size(); i++) { shape_inference::ShapeHandle shape; diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 73a84852a1a..e14c253c71c 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -191,7 +191,7 @@ Status ShapeRefiner::InferShapesForFunction( Status ShapeRefiner::AddNode(const Node* node) { // Create the inference context for this node with the existing input shapes. std::unique_ptr<InferenceContext> ic(new InferenceContext( - graph_def_version_, &node->def(), node->op_def(), + graph_def_version_, node->def(), node->op_def(), std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {})); TF_RETURN_IF_ERROR(ic->construction_status()); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 19642efe389..68c448c8007 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -63,7 +63,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) { .Input({{"data", 0, DT_FLOAT}}) .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({}), S({10})}, {}, {}, {}); TF_EXPECT_OK(NoOutputs(&c)); EXPECT_EQ(0, c.num_outputs()); @@ -82,15 +82,15 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) { NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def)); { - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({})}, {}, {}, {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); } { - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, - {S({1, 23, 4, 4, 2})}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({1, 23, 4, 4, 2})}, + {}, {}, {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); @@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Finalize(&def)); { - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); @@ -127,7 +127,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown inner dimension for one - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, 
op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); @@ -137,7 +137,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Invalid rank. - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2}), S({3, 4})}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); @@ -147,7 +147,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown outer dimension - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); @@ -157,7 +157,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); @@ -168,7 +168,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); @@ -186,7 +186,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); @@ -204,7 +204,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); @@ -420,8 +420,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Finalize(&def)); { - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, - {S({2, 10}), S({10})}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10}), S({10})}, + {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -430,7 +430,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Unknown ranks. 
- InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {Unknown(), Unknown()}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); @@ -439,7 +439,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Rank > 2 - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); @@ -453,7 +453,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); @@ -467,7 +467,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } @@ -479,7 +479,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({10, 11, 12}), S({11})}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); @@ -488,7 +488,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Input rank not high enough - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3}), S({3})}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } @@ -501,7 +501,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3})}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } @@ -548,7 +548,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Finalize(&def)); { - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10})}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); @@ -557,7 +557,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Rank > 2 - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); @@ -570,8 +570,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3, 4, 5})}, {}, + {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(3, c.Value(c.Dim(output, 0))); @@ -583,7 +583,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({8, 6, 4, 2, 3, 4, 
5})}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); @@ -596,8 +596,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({10, 11, 12})}, {}, + {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(11, c.Value(c.Dim(output, 0))); @@ -605,8 +605,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Input rank not high enough - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, - {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3})}, {}, {}, {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } @@ -617,7 +616,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {}, + InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3})}, {}, {}, {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } @@ -1353,7 +1352,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1366,7 +1365,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1379,7 +1378,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1393,7 +1392,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1407,7 +1406,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1421,7 +1420,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ 
-1434,7 +1433,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1447,7 +1446,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1460,7 +1459,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1473,7 +1472,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor) { NodeDef def; - InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 064cc752291..5630898da6b 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -31,15 +31,14 @@ constexpr int64 InferenceContext::kUnknownDim; // Same as above, but with PartialTensorShape instead of TensorShapeProto InferenceContext::InferenceContext( - int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + int graph_def_version, const NodeDef& node_def, const OpDef& op_def, const std::vector<PartialTensorShape>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<PartialTensorShape>& input_tensors_as_shapes, const std::vector< std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>& input_handle_shapes_and_types) - : graph_def_version_(graph_def_version), - node_def_(CHECK_NOTNULL(node_def)) { + : graph_def_version_(graph_def_version), node_def_(node_def) { std::vector<ShapeHandle> input_tensors_as_shape_handles; input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size()); for (const PartialTensorShape& p : input_tensors_as_shapes) { @@ -84,14 +83,13 @@ InferenceContext::InferenceContext( } InferenceContext::InferenceContext( - int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + int graph_def_version, const NodeDef& node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types) - : graph_def_version_(graph_def_version), - node_def_(CHECK_NOTNULL(node_def)) { + : graph_def_version_(graph_def_version), node_def_(node_def) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; inputs_ = input_shapes; @@ 
-112,7 +110,7 @@ Status InferenceContext::Run( #ifndef NDEBUG for (int i = 0; i < num_outputs(); ++i) { DCHECK(output(i).IsSet()) - << i << " for " << node_def_->name() << " of type " << node_def_->op(); + << i << " for " << node_def_.name() << " of type " << node_def_.op(); } #endif // NDEBUG return s; @@ -171,8 +169,8 @@ void InferenceContext::PreInputInit( input_tensors_ = input_tensors; input_tensors_as_shapes_ = input_tensors_as_shapes; - construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_, - &output_name_map_); + construction_status_ = + NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_); if (!construction_status_.ok()) return; int num_outputs = 0; @@ -290,7 +288,7 @@ string InferenceContext::DebugString(DimensionHandle d) { string InferenceContext::DebugString() const { return strings::StrCat("InferenceContext for node: ", - node_def_->DebugString()); + node_def_.DebugString()); } string InferenceContext::DebugString(const ShapeAndType& shape_and_type) { @@ -1119,7 +1117,7 @@ Status InferenceContext::AttachContext(const Status& status) { } string error_context = strings::StrCat( - " for '", node_def_->name(), "' (op: '", node_def_->op(), + " for '", node_def_.name(), "' (op: '", node_def_.op(), "') with input shapes: ", absl::StrJoin(input_shapes, ", ")); if (!input_from_tensors_str.empty()) { strings::StrAppend(&error_context, " and with computed input tensors: ", diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index a654f595e23..a140657eb15 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -161,9 +161,7 @@ class InferenceContext { // known from analysis of the graph. // <input_tensors_as_shapes> can have fewer elements than <input_shapes>. // Values of <input_tensors_as_shapes> do not need to outlive the context. - // - // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. - InferenceContext(int graph_def_version, const NodeDef* node_def, + InferenceContext(int graph_def_version, const NodeDef& node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, @@ -179,11 +177,8 @@ class InferenceContext { // partially known from analysis of the graph. <input_tensors_as_shapes> // can have fewer elements than <input_shapes>. Values of // <input_tensors_as_shapes> do not need to outlive the context. - // - // REQUIRES: <node_def> is not NULL, and must outlive the - // InferenceContext. InferenceContext( - int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + int graph_def_version, const NodeDef& node_def, const OpDef& op_def, const std::vector<PartialTensorShape>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<PartialTensorShape>& input_tensors_as_shapes, @@ -306,7 +301,7 @@ class InferenceContext { Status output(StringPiece output_name, std::vector<ShapeHandle>* output) const; - AttrSlice attrs() const { return AttrSlice(*node_def_); } + AttrSlice attrs() const { return AttrSlice(node_def_); } // idx can be negative for an offset from end of dimensions. // idx must be in the range [-1 * s.rank, s.rank). 
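For downstream code the migration is mechanical: call sites pass the NodeDef itself rather than its address. The sketch below is illustrative only (the helper name, op shape count, and empty trailing arguments are assumptions mirroring the test call sites in this patch, not code taken from it); it shows a call site before and after the change and notes that the lifetime requirement is unchanged, because the context still holds only a reference.

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/public/version.h"

namespace si = tensorflow::shape_inference;

// Hypothetical helper: node_def/op_def are assumed to describe a two-input op.
void MakeContext(const tensorflow::NodeDef& node_def,
                 const tensorflow::OpDef& op_def) {
  // Before this change the NodeDef was passed as a non-null pointer:
  //   si::InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_def, ...);
  // After this change it is passed by const reference:
  si::InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_def,
                         std::vector<si::ShapeHandle>(2), {}, {}, {});
  // node_def_ is now stored as `const NodeDef&`, so the NodeDef must still
  // outlive the InferenceContext, just as with the old pointer form.
}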
@@ -737,7 +732,7 @@ class InferenceContext { output_handle_shapes_and_types_; const int graph_def_version_; - const NodeDef* node_def_; + const NodeDef& node_def_; NameRangeMap input_name_map_; NameRangeMap output_name_map_; @@ -784,7 +779,7 @@ inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) { template <class T> Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { - return GetNodeAttr(*node_def_, attr_name, value); + return GetNodeAttr(node_def_, attr_name, value); } } // namespace shape_inference diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 00bd71a868c..08fec604e2d 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -77,7 +77,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) { .Attr("N", 3) .Input(FakeInput(DT_FLOAT)) .Finalize(&def); - InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, + InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {}, {}); EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); @@ -114,7 +114,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) { TEST_F(ShapeInferenceTest, DimensionOrConstant) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}); EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(InferenceContext::kUnknownDim)); EXPECT_EQ(1, c.Value(1)); @@ -129,7 +129,7 @@ TEST_F(ShapeInferenceTest, Run) { NodeDef def; def.set_name("foo"); def.set_op("foo_op"); - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}); TF_ASSERT_OK(c.construction_status()); { @@ -167,7 +167,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { def.set_op("foo_op"); // Error when no constant tensors were requested. { - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { @@ -186,7 +186,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { { Tensor input_t = ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5}); - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { @@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { // shapes provided. { Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})}, {nullptr, &input_t}, {}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { @@ -231,7 +231,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { // shape was provided. 
{ Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})}, {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { @@ -254,7 +254,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { TEST_F(ShapeInferenceTest, RankAndDimInspection) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + InferenceContext c(kVersion, def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(2, c.num_outputs()); @@ -295,7 +295,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) { TEST_F(ShapeInferenceTest, NumElements) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + InferenceContext c(kVersion, def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}); EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0)))); @@ -309,8 +309,8 @@ TEST_F(ShapeInferenceTest, NumElements) { TEST_F(ShapeInferenceTest, WithRank) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), - {Unknown(), S({1, -1, 3})}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, + {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -348,8 +348,8 @@ TEST_F(ShapeInferenceTest, WithRank) { TEST_F(ShapeInferenceTest, WithRankAtMost) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), - {Unknown(), S({1, -1, 3})}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, + {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -386,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) { TEST_F(ShapeInferenceTest, WithRankAtLeast) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), - {Unknown(), S({1, -1, 3})}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, + {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -424,7 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) { TEST_F(ShapeInferenceTest, WithValue) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}); auto d0 = c.Dim(c.input(0), 0); auto d1 = c.Dim(c.input(0), 1); @@ -467,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) { TEST_F(ShapeInferenceTest, MergeDim) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, - {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, + {}, {}); auto d2 = c.Dim(c.input(0), 0); auto d_unknown = c.Dim(c.input(0), 1); @@ -530,7 +530,7 @@ TEST_F(ShapeInferenceTest, MergeDim) { TEST_F(ShapeInferenceTest, RelaxDim) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, InferenceContext::kUnknownDim, 2, 1, InferenceContext::kUnknownDim})}, {}, {}, {}); @@ -578,7 +578,7 @@ TEST_F(ShapeInferenceTest, RelaxDim) { TEST_F(ShapeInferenceTest, RelaxShape) { NodeDef def; InferenceContext c( - kVersion, &def, MakeOpDef(7, 2), + kVersion, def, MakeOpDef(7, 2), {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}), S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})}, {}, {}, {}); @@ -647,7 +647,7 @@ TEST_F(ShapeInferenceTest, RelaxShape) { 
TEST_F(ShapeInferenceTest, MergeShape) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(7, 2), + InferenceContext c(kVersion, def, MakeOpDef(7, 2), {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), Unknown(), S({1})}, {}, {}, {}); @@ -753,7 +753,7 @@ TEST_F(ShapeInferenceTest, MergeShape) { TEST_F(ShapeInferenceTest, MergePrefix) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(4, 2), + InferenceContext c(kVersion, def, MakeOpDef(4, 2), { Unknown(), S({-1, 2}), @@ -808,7 +808,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) { TEST_F(ShapeInferenceTest, Subshape) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}); ShapeHandle unknown = c.input(1); @@ -880,7 +880,7 @@ TEST_F(ShapeInferenceTest, Subshape) { TEST_F(ShapeInferenceTest, Concatenate) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}); auto in0 = c.input(0); @@ -907,7 +907,7 @@ TEST_F(ShapeInferenceTest, Concatenate) { TEST_F(ShapeInferenceTest, ReplaceDim) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, + InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {}, {}); auto in = c.input(0); @@ -939,7 +939,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) { TEST_F(ShapeInferenceTest, MakeShape) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {}); std::vector<DimensionHandle> dims; @@ -966,7 +966,7 @@ TEST_F(ShapeInferenceTest, MakeShape) { TEST_F(ShapeInferenceTest, UnknownShape) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto u0 = c.UnknownShape(); auto u1 = c.UnknownShape(); @@ -978,7 +978,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) { TEST_F(ShapeInferenceTest, KnownShapeToProto) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto s = c.MakeShape({1, 2, 3}); TensorShapeProto proto; @@ -992,7 +992,7 @@ TEST_F(ShapeInferenceTest, KnownShapeToProto) { TEST_F(ShapeInferenceTest, UnknownShapeToProto) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto u0 = c.UnknownShape(); TensorShapeProto proto; @@ -1005,7 +1005,7 @@ TEST_F(ShapeInferenceTest, UnknownShapeToProto) { TEST_F(ShapeInferenceTest, Scalar) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto s0 = c.Scalar(); EXPECT_EQ("[]", c.DebugString(s0)); @@ -1016,7 +1016,7 @@ TEST_F(ShapeInferenceTest, Scalar) { TEST_F(ShapeInferenceTest, Vector) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto s0 = c.Vector(1); EXPECT_EQ("[1]", c.DebugString(s0)); @@ -1032,7 +1032,7 @@ 
TEST_F(ShapeInferenceTest, Vector) { TEST_F(ShapeInferenceTest, Matrix) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto s0 = c.Matrix(1, 2); EXPECT_EQ("[1,2]", c.DebugString(s0)); @@ -1054,7 +1054,7 @@ TEST_F(ShapeInferenceTest, Matrix) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { auto create = [&](Tensor* t) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}); ShapeHandle out; Status s = c.MakeShapeFromShapeTensor(0, &out); @@ -1115,7 +1115,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { // Test when the input shape is wrong. { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, + InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {}); ShapeHandle out; EXPECT_EQ("Shape must be rank 1 but is rank 2", @@ -1126,7 +1126,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); // With an unknown rank. ShapeHandle out; @@ -1145,7 +1145,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); ShapeHandle out; TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out)); @@ -1159,7 +1159,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); TensorShapeProto proto; // With a set unknown rank. 
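The test hunks above all follow the same pattern: build a NodeDef, construct an InferenceContext directly with the new by-reference signature, and exercise its shape-construction helpers. A minimal standalone sketch of that pattern, with assumed op name, input count, and tensor values (the op_def is assumed to declare a single input):

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/version.h"

namespace si = tensorflow::shape_inference;

// Sketch: one unknown-shaped input plus a constant tensor that
// MakeShapeFromShapeTensor turns into a ShapeHandle.
void ShapeFromTensorSketch(const tensorflow::NodeDef& def,
                           const tensorflow::OpDef& op_def) {
  tensorflow::Tensor t =
      tensorflow::test::AsTensor<tensorflow::int32>({1, 2, 3});
  si::InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                         {tensorflow::PartialTensorShape()},  // unknown input
                         {&t}, {}, {});
  si::ShapeHandle out;
  TF_CHECK_OK(c.MakeShapeFromShapeTensor(0, &out));
  // `out` now describes the shape [1,2,3], as c.DebugString(out) would show.
}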
@@ -1195,7 +1195,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { TEST_F(ShapeInferenceTest, MakeDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto d0 = c.MakeDim(1); auto d1 = c.MakeDim(1); @@ -1209,7 +1209,7 @@ TEST_F(ShapeInferenceTest, MakeDim) { TEST_F(ShapeInferenceTest, UnknownDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto d0 = c.UnknownDim(); auto d1 = c.UnknownDim(); @@ -1221,7 +1221,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) { TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); @@ -1234,7 +1234,7 @@ TEST_F(ShapeInferenceTest, InputTensors) { const Tensor t1 = tensorflow::test::AsTensor<float>({10}); const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, + InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, {&t1, &t2}, {}, {}); EXPECT_TRUE(c.input_tensor(0) == &t1); @@ -1246,8 +1246,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { Tensor t1 = tensorflow::test::AsScalar<int32>(20); Tensor t2 = tensorflow::test::AsScalar<int32>(-1); NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, - {&t1, &t2}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, + {}, {}); DimensionHandle d; EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); @@ -1280,7 +1280,7 @@ TEST_F(ShapeInferenceTest, GetAttr) { .ok()); std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}); + InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {}); string value; EXPECT_TRUE(c.GetAttr("foo", &value).ok()); EXPECT_EQ("bar", value); @@ -1288,7 +1288,7 @@ TEST_F(ShapeInferenceTest, GetAttr) { TEST_F(ShapeInferenceTest, Divide) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {}); auto s = c.input(0); @@ -1351,7 +1351,7 @@ TEST_F(ShapeInferenceTest, Divide) { TEST_F(ShapeInferenceTest, Add) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}); auto s = c.input(0); @@ -1401,8 +1401,8 @@ TEST_F(ShapeInferenceTest, Add) { TEST_F(ShapeInferenceTest, Subtract) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, - {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, + {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1451,8 +1451,8 @@ TEST_F(ShapeInferenceTest, Subtract) { TEST_F(ShapeInferenceTest, Multiply) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, - {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, + {}); auto s = c.input(0); auto 
d_6 = c.Dim(s, 0); @@ -1505,7 +1505,7 @@ TEST_F(ShapeInferenceTest, Multiply) { TEST_F(ShapeInferenceTest, FullyDefined) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {}); // No rank or missing dimension information should return false. EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); @@ -1518,8 +1518,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) { TEST_F(ShapeInferenceTest, Min) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, - {}, {}); + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, + {}); auto s = c.input(0); auto d_1 = c.Dim(s, 0); @@ -1567,7 +1567,7 @@ TEST_F(ShapeInferenceTest, Min) { TEST_F(ShapeInferenceTest, Max) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, + InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}); auto s = c.input(0); @@ -1605,7 +1605,7 @@ TEST_F(ShapeInferenceTest, Max) { void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, {}); auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { ShapeHandle s; @@ -1716,7 +1716,7 @@ TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) { void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { NodeDef def; - InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, + InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, {}); auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { ShapeHandle s; diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index 2cf447471e3..3145a6e5954 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -60,7 +60,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } } shape_inference::InferenceContext c( - op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes, + op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes, op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types)); TF_RETURN_IF_ERROR(c.construction_status()); if (op_reg_data->shape_inference_fn == nullptr) { diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 8e6e72993f8..925ea44d454 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1168,7 +1168,7 @@ class SymbolicShapeRefiner { std::vector<ShapeHandle> input_tensors_as_shapes; node_ctx.inference_context.reset(new InferenceContext( - graph_def_version_, node, node_ctx.op_data->op_def, input_shapes, + graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes, input_tensors, input_tensors_as_shapes, std::move(input_handle_shapes_and_types))); const Status s = node_ctx.inference_context->construction_status(); diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 4cda29ee2a9..718a34c07e6 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -198,7 +198,7 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) { new 
std::vector<std::pair<PartialTensorShape, DataType>>( {{PartialTensorShape(), DT_BOOL}})); shape_inference::InferenceContext c( - TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, + TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def, {PartialTensorShape()}, {}, {}, handle_data); TF_ASSERT_OK(c.construction_status()); ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 33a010b9349..7ebd7889a35 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) { auto run_inference_for_handles = [&]() -> Status { CHECK(op_reg_data->shape_inference_fn != nullptr); c.reset(new shape_inference::InferenceContext( - TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, + TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def, {PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {}, {}, handle_data)); TF_CHECK_OK(c->construction_status());
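Taken together, the pattern every call site in this patch now follows looks roughly like the sketch below: build a NodeDef, look up the registered op, construct the InferenceContext with the NodeDef passed by const reference, and run the op's shape function. This is a sketch under assumptions (a made-up "example_matmul" node, illustrative input shapes, and the MatMul op assumed to be registered in the binary), not code from the patch.

#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace si = tensorflow::shape_inference;

// Infers the output shape of a MatMul over [2,3] x [3,4] inputs. Node name and
// shapes are illustrative only.
tensorflow::Status InferMatMulShape() {
  using tensorflow::NodeDefBuilder;
  using tensorflow::PartialTensorShape;

  tensorflow::NodeDef def;
  TF_RETURN_IF_ERROR(NodeDefBuilder("example_matmul", "MatMul")
                         .Input("a", 0, tensorflow::DT_FLOAT)
                         .Input("b", 0, tensorflow::DT_FLOAT)
                         .Finalize(&def));

  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      tensorflow::OpRegistry::Global()->LookUp(def.op(), &op_reg_data));

  // NodeDef is now passed by const reference; it must outlive `c`.
  si::InferenceContext c(
      TF_GRAPH_DEF_VERSION, def, op_reg_data->op_def,
      {PartialTensorShape({2, 3}), PartialTensorShape({3, 4})}, {}, {}, {});
  TF_RETURN_IF_ERROR(c.construction_status());
  TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
  // c.output(0) should now describe the [2,4] result shape.
  return tensorflow::Status::OK();
}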