From a6ac9040dd2d447247031b8f92ce1ca62b159291 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev@google.com>
Date: Mon, 7 Oct 2019 14:49:17 -0700
Subject: [PATCH] Changed InferenceContext ctor to accept NodeDef as a const
 ref

It has been a const ref once, but cl/170078811 made it a non-const pointer
to allow resetting it in ShapeRefiner::InferShapesForFunction. This code
path is no longer used.

PiperOrigin-RevId: 273381405
---
 tensorflow/c/c_api_experimental.cc            |   2 +-
 tensorflow/c/kernels/bitcast_op_test.cc       |   6 +-
 tensorflow/c/ops_test.cc                      |  10 +-
 .../tensorflow/transforms/shape_inference.cc  |   2 +-
 .../common_runtime/eager/shape_inference.cc   |   2 +-
 .../core/common_runtime/shape_refiner.cc      |   2 +-
 .../core/framework/common_shape_fns_test.cc   |  81 +++++++------
 tensorflow/core/framework/shape_inference.cc  |  20 ++--
 tensorflow/core/framework/shape_inference.h   |  15 +--
 .../core/framework/shape_inference_test.cc    | 112 +++++++++---------
 .../framework/shape_inference_testutil.cc     |   2 +-
 .../core/grappler/costs/graph_properties.cc   |   2 +-
 tensorflow/core/ops/array_ops_test.cc         |   2 +-
 tensorflow/core/ops/math_ops_test.cc          |   2 +-
 14 files changed, 126 insertions(+), 134 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 12e714ea9e1..31a0f0e97f6 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -1111,7 +1111,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
   }
 
   // Create an inference context with dummy values, which will be updated later.
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
+  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                      std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                      {},
                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc
index 7e6dbe14725..7da27e99d1f 100644
--- a/tensorflow/c/kernels/bitcast_op_test.cc
+++ b/tensorflow/c/kernels/bitcast_op_test.cc
@@ -114,7 +114,7 @@ TEST(BitcastOpTest, TestShapeInference_LargerShape) {
                   .Attr("T", DT_INT64)
                   .Input(FakeInput(DT_INT64))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
   ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
@@ -132,7 +132,7 @@ TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
                   .Attr("T", DT_INT8)
                   .Input(FakeInput(DT_INT8))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4, 8})}, {}, {},
                                       {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
@@ -151,7 +151,7 @@ TEST(BitcastOpTest, TestShapeInference_SameShape) {
                   .Attr("T", DT_FLOAT)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
-  shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
+  shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
   std::vector<shape_inference::ShapeHandle> input_shapes;
   TF_CHECK_OK(c.input("input", &input_shapes));
   ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
diff --git a/tensorflow/c/ops_test.cc b/tensorflow/c/ops_test.cc
index 0a6c5cd50fb..2e0a8e92b01 100644
--- a/tensorflow/c/ops_test.cc
+++ b/tensorflow/c/ops_test.cc
@@ -196,7 +196,7 @@ PartialTensorShape Unknown() { return PartialTensorShape(); }
 
 TEST(OpsTest, ShapeInferenceWithRank) {
   NodeDef def;
-  shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
+  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
                                       {S({10, 20, 30})}, {}, {}, {});
 
   shape_inference::ShapeHandle in0 = c.input(0);
@@ -236,7 +236,7 @@ TEST(OpsTest, ShapeInferenceWithRank) {
 
 TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
   NodeDef def;
-  shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2),
+  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2),
                                       {Unknown(), S({1, -1, 3})}, {}, {}, {});
 
   shape_inference::ShapeHandle in0 = c.input(0);
@@ -260,7 +260,7 @@ TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
 
 TEST(OpsTest, ShapeInferenceConcatenateShapes) {
   NodeDef def;
-  shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
+  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
                                       {S({1, 2}), S({3, 4})}, {}, {}, {});
   ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
   shape_inference::ShapeHandle a = c.input(0);
@@ -279,7 +279,7 @@ TEST(OpsTest, ShapeInferenceConcatenateShapes) {
 
 TEST(OpsTest, DimensionHandleValueKnown) {
   NodeDef def;
-  shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
+  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
                                       {S({1, 2}), S({3, 4})}, {}, {}, {});
   TF_ShapeHandle* handle =
       TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
@@ -299,7 +299,7 @@ TEST(OpsTest, DimensionHandleValueKnown) {
 
 TEST(OpsTest, ShapeInferenceSubshape) {
   NodeDef def;
-  shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
+  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
                                       {S({10, 20, 30, 40, 50})}, {}, {}, {});
   ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 54fc64419bb..813460c5ab9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -115,7 +115,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   // shapes. This object is abstracting the information that the ShapeInference
   // function operates on.
   tensorflow::shape_inference::InferenceContext c(
-      graph_version, node_def.get(), op_reg_data->op_def, input_shapes,
+      graph_version, *node_def, op_reg_data->op_def, input_shapes,
       /*input_tensors=*/{}, /*input_tensors_as_shapes=*/{},
       /*input_handle_shapes_and_types=*/{});
   auto status = c.Run(op_reg_data->shape_inference_fn);
diff --git a/tensorflow/core/common_runtime/eager/shape_inference.cc b/tensorflow/core/common_runtime/eager/shape_inference.cc
index 43ef02179a0..6df2d8c5465 100644
--- a/tensorflow/core/common_runtime/eager/shape_inference.cc
+++ b/tensorflow/core/common_runtime/eager/shape_inference.cc
@@ -37,7 +37,7 @@ Status RunShapeInference(const NodeDef& ndef,
   if (op_reg_data->shape_inference_fn == nullptr) return Status::OK();
 
   shape_inference::InferenceContext ic(
-      TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def,
+      TF_GRAPH_DEF_VERSION, ndef, op_reg_data->op_def,
       std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {});
   for (size_t i = 0; i < inputs.size(); i++) {
     shape_inference::ShapeHandle shape;
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 73a84852a1a..e14c253c71c 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -191,7 +191,7 @@ Status ShapeRefiner::InferShapesForFunction(
 Status ShapeRefiner::AddNode(const Node* node) {
   // Create the inference context for this node with the existing input shapes.
   std::unique_ptr<InferenceContext> ic(new InferenceContext(
-      graph_def_version_, &node->def(), node->op_def(),
+      graph_def_version_, node->def(), node->op_def(),
       std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
   TF_RETURN_IF_ERROR(ic->construction_status());
 
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 19642efe389..68c448c8007 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -63,7 +63,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
                   .Input({{"data", 0, DT_FLOAT}})
                   .Finalize(&def));
 
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({}), S({10})}, {},
                      {}, {});
   TF_EXPECT_OK(NoOutputs(&c));
   EXPECT_EQ(0, c.num_outputs());
@@ -82,15 +82,15 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
       NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
 
   {
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({})}, {}, {}, {});
     TF_EXPECT_OK(ScalarShape(&c));
     ShapeHandle output = c.output(0);
     EXPECT_EQ(0, c.Rank(output));
   }
 
   {
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
-                       {S({1, 23, 4, 4, 2})}, {}, {}, {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({1, 23, 4, 4, 2})},
+                       {}, {}, {});
     TF_EXPECT_OK(ScalarShape(&c));
     ShapeHandle output = c.output(0);
     EXPECT_EQ(0, c.Rank(output));
@@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
                   .Finalize(&def));
 
   {
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 3}), S({3, 4})}, {}, {}, {});
     TF_EXPECT_OK(MatMulShape(&c));
     ShapeHandle output = c.output(0);
@@ -127,7 +127,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
 
   {
     // Unknown inner dimension for one
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, -1}), S({3, 4})}, {}, {}, {});
     TF_EXPECT_OK(MatMulShape(&c));
     ShapeHandle output = c.output(0);
@@ -137,7 +137,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
 
   {
     // Invalid rank.
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2}), S({3, 4})},
                        {}, {}, {});
     auto s = MatMulShape(&c);
     EXPECT_FALSE(s.ok());
@@ -147,7 +147,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
 
   {
     // Unknown outer dimension
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 3}), S({3, -1})}, {}, {}, {});
     TF_EXPECT_OK(MatMulShape(&c));
     ShapeHandle output = c.output(0);
@@ -157,7 +157,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
 
   {
     // Inner shapes not compatible
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 5}), S({3, 4})}, {}, {}, {});
     auto s = MatMulShape(&c);
     EXPECT_FALSE(s.ok());
@@ -168,7 +168,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
 
   {
     // Inner shapes not compatible
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
     auto s = MatMulShape(&c);
     EXPECT_FALSE(s.ok());
@@ -186,7 +186,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
                     .Attr("type", DT_FLOAT)
                     .Finalize(&def));
 
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({3, 2}), S({3, 4})}, {}, {}, {});
     auto s = MatMulShape(&c);
     ShapeHandle output = c.output(0);
@@ -204,7 +204,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
                     .Attr("type", DT_FLOAT)
                     .Finalize(&def));
 
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 3}), S({4, 3})}, {}, {}, {});
     auto s = MatMulShape(&c);
     ShapeHandle output = c.output(0);
@@ -420,8 +420,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                   .Finalize(&def));
 
   {
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
-                       {S({2, 10}), S({10})}, {}, {}, {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10}), S({10})},
+                       {}, {}, {});
     TF_EXPECT_OK(BiasAddShape(&c));
     ShapeHandle output = c.output(0);
     EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -430,7 +430,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
 
   {
     // Unknown ranks.
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {Unknown(), Unknown()}, {}, {}, {});
     TF_EXPECT_OK(BiasAddShape(&c));
     ShapeHandle output = c.output(0);
@@ -439,7 +439,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
 
   {
     // Rank > 2
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
     TF_EXPECT_OK(BiasAddShape(&c));
     ShapeHandle output = c.output(0);
@@ -453,7 +453,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Input("b", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({2, 3, 4, 5}), S({3})}, {}, {}, {});
     TF_EXPECT_OK(BiasAddShape(&c));
     ShapeHandle output = c.output(0);
@@ -467,7 +467,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Input("b", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
   }
@@ -479,7 +479,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Input("b", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({10, 11, 12}), S({11})}, {}, {}, {});
     TF_EXPECT_OK(BiasAddShape(&c));
     ShapeHandle output = c.output(0);
@@ -488,7 +488,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
 
   {
     // Input rank not high enough
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3}), S({3})}, {},
                        {}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
   }
@@ -501,7 +501,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3})},
                        {}, {}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
   }
@@ -548,7 +548,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                   .Finalize(&def));
 
   {
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10})}, {}, {},
                        {});
     TF_EXPECT_OK(BiasAddGradShape(&c));
     ShapeHandle output = c.output(0);
@@ -557,7 +557,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
 
   {
     // Rank > 2
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({5, 7, 2, 10})},
                        {}, {}, {});
     TF_EXPECT_OK(BiasAddGradShape(&c));
     ShapeHandle output = c.output(0);
@@ -570,8 +570,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Input("a", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
-                       {}, {}, {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3, 4, 5})}, {},
+                       {}, {});
     TF_EXPECT_OK(BiasAddGradShape(&c));
     ShapeHandle output = c.output(0);
     EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -583,7 +583,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Input("a", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
                        {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
     TF_EXPECT_OK(BiasAddGradShape(&c));
     ShapeHandle output = c.output(0);
@@ -596,8 +596,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Input("a", 0, DT_FLOAT)
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
-                       {}, {}, {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({10, 11, 12})}, {},
+                       {}, {});
     TF_EXPECT_OK(BiasAddGradShape(&c));
     ShapeHandle output = c.output(0);
     EXPECT_EQ(11, c.Value(c.Dim(output, 0)));
@@ -605,8 +605,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
 
   {
     // Input rank not high enough
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
-                       {});
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3})}, {}, {}, {});
     EXPECT_FALSE(BiasAddGradShape(&c).ok());
   }
 
@@ -617,7 +616,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
+    InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3})}, {}, {},
                        {});
     EXPECT_FALSE(BiasAddGradShape(&c).ok());
   }
@@ -1353,7 +1352,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {Unknown(), Unknown(), Unknown()}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1366,7 +1365,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1379,7 +1378,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({-1}), S({-1}), S({-1})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1393,7 +1392,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, 3}), S({4}), S({3})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1407,7 +1406,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, 3}), S({5}), S({4})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1421,7 +1420,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({-1, 3}), S({5}), S({3})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1434,7 +1433,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, 3}), S({-1}), S({3})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1447,7 +1446,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, -1}), S({5}), S({3})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1460,7 +1459,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, 3}), S({5}), S({-1})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
@@ -1473,7 +1472,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
 
 TEST(CommonShapeFnsTest, ValidateSparseTensor) {
   NodeDef def;
-  InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+  InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
                      {S({5, 3}), S({5}), S({3})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(1, c.num_outputs());
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 064cc752291..5630898da6b 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -31,15 +31,14 @@ constexpr int64 InferenceContext::kUnknownDim;
 
 // Same as above, but with PartialTensorShape instead of TensorShapeProto
 InferenceContext::InferenceContext(
-    int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
+    int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
     const std::vector<PartialTensorShape>& input_shapes,
     const std::vector<const Tensor*>& input_tensors,
     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
     const std::vector<
         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
         input_handle_shapes_and_types)
-    : graph_def_version_(graph_def_version),
-      node_def_(CHECK_NOTNULL(node_def)) {
+    : graph_def_version_(graph_def_version), node_def_(node_def) {
   std::vector<ShapeHandle> input_tensors_as_shape_handles;
   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
   for (const PartialTensorShape& p : input_tensors_as_shapes) {
@@ -84,14 +83,13 @@ InferenceContext::InferenceContext(
 }
 
 InferenceContext::InferenceContext(
-    int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
+    int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
     const std::vector<ShapeHandle>& input_shapes,
     const std::vector<const Tensor*>& input_tensors,
     const std::vector<ShapeHandle>& input_tensors_as_shapes,
     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
         input_handle_shapes_and_types)
-    : graph_def_version_(graph_def_version),
-      node_def_(CHECK_NOTNULL(node_def)) {
+    : graph_def_version_(graph_def_version), node_def_(node_def) {
   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
   if (!construction_status_.ok()) return;
   inputs_ = input_shapes;
@@ -112,7 +110,7 @@ Status InferenceContext::Run(
 #ifndef NDEBUG
   for (int i = 0; i < num_outputs(); ++i) {
     DCHECK(output(i).IsSet())
-        << i << " for " << node_def_->name() << " of type " << node_def_->op();
+        << i << " for " << node_def_.name() << " of type " << node_def_.op();
   }
 #endif  // NDEBUG
   return s;
@@ -171,8 +169,8 @@ void InferenceContext::PreInputInit(
   input_tensors_ = input_tensors;
   input_tensors_as_shapes_ = input_tensors_as_shapes;
 
-  construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
-                                           &output_name_map_);
+  construction_status_ =
+      NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
   if (!construction_status_.ok()) return;
 
   int num_outputs = 0;
@@ -290,7 +288,7 @@ string InferenceContext::DebugString(DimensionHandle d) {
 
 string InferenceContext::DebugString() const {
   return strings::StrCat("InferenceContext for node: ",
-                         node_def_->DebugString());
+                         node_def_.DebugString());
 }
 
 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
@@ -1119,7 +1117,7 @@ Status InferenceContext::AttachContext(const Status& status) {
   }
 
   string error_context = strings::StrCat(
-      " for '", node_def_->name(), "' (op: '", node_def_->op(),
+      " for '", node_def_.name(), "' (op: '", node_def_.op(),
       "') with input shapes: ", absl::StrJoin(input_shapes, ", "));
   if (!input_from_tensors_str.empty()) {
     strings::StrAppend(&error_context, " and with computed input tensors: ",
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index a654f595e23..a140657eb15 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -161,9 +161,7 @@ class InferenceContext {
   // known from analysis of the graph.
   // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
   // Values of <input_tensors_as_shapes> do not need to outlive the context.
-  //
-  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
-  InferenceContext(int graph_def_version, const NodeDef* node_def,
+  InferenceContext(int graph_def_version, const NodeDef& node_def,
                    const OpDef& op_def,
                    const std::vector<ShapeHandle>& input_shapes,
                    const std::vector<const Tensor*>& input_tensors,
@@ -179,11 +177,8 @@ class InferenceContext {
   // partially known from analysis of the graph. <input_tensors_as_shapes>
   // can have fewer elements than <input_shapes>. Values of
   // <input_tensors_as_shapes> do not need to outlive the context.
-  //
-  // REQUIRES: <node_def> is not NULL, and must outlive the
-  // InferenceContext.
   InferenceContext(
-      int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
+      int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
       const std::vector<PartialTensorShape>& input_shapes,
       const std::vector<const Tensor*>& input_tensors,
       const std::vector<PartialTensorShape>& input_tensors_as_shapes,
@@ -306,7 +301,7 @@ class InferenceContext {
   Status output(StringPiece output_name,
                 std::vector<ShapeHandle>* output) const;
 
-  AttrSlice attrs() const { return AttrSlice(*node_def_); }
+  AttrSlice attrs() const { return AttrSlice(node_def_); }
 
   // idx can be negative for an offset from end of dimensions.
   // idx must be in the range [-1 * s.rank, s.rank).
@@ -737,7 +732,7 @@ class InferenceContext {
       output_handle_shapes_and_types_;
 
   const int graph_def_version_;
-  const NodeDef* node_def_;
+  const NodeDef& node_def_;
   NameRangeMap input_name_map_;
   NameRangeMap output_name_map_;
 
@@ -784,7 +779,7 @@ inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
 
 template <class T>
 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
-  return GetNodeAttr(*node_def_, attr_name, value);
+  return GetNodeAttr(node_def_, attr_name, value);
 }
 
 }  // namespace shape_inference
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 00bd71a868c..08fec604e2d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -77,7 +77,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
                .Attr("N", 3)
                .Input(FakeInput(DT_FLOAT))
                .Finalize(&def);
-  InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
+  InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                      {}, {}, {});
 
   EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
@@ -114,7 +114,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
 
 TEST_F(ShapeInferenceTest, DimensionOrConstant) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
   EXPECT_EQ(InferenceContext::kUnknownDim,
             c.Value(InferenceContext::kUnknownDim));
   EXPECT_EQ(1, c.Value(1));
@@ -129,7 +129,7 @@ TEST_F(ShapeInferenceTest, Run) {
   NodeDef def;
   def.set_name("foo");
   def.set_op("foo_op");
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
   TF_ASSERT_OK(c.construction_status());
 
   {
@@ -167,7 +167,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
   def.set_op("foo_op");
   // Error when no constant tensors were requested.
   {
-    InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
+    InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
                        {});
     TF_ASSERT_OK(c.construction_status());
     auto fn = [](InferenceContext* c) {
@@ -186,7 +186,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
   {
     Tensor input_t =
         ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
-    InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+    InferenceContext c(kVersion, def, MakeOpDef(2, 2),
                        {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
     TF_ASSERT_OK(c.construction_status());
     auto fn = [](InferenceContext* c) {
@@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
   // shapes provided.
   {
     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
-    InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
+    InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
                        {nullptr, &input_t}, {}, {});
     TF_ASSERT_OK(c.construction_status());
     auto fn = [](InferenceContext* c) {
@@ -231,7 +231,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
   // shape was provided.
   {
     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
-    InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
+    InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
                        {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
     TF_ASSERT_OK(c.construction_status());
     auto fn = [](InferenceContext* c) {
@@ -254,7 +254,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
 
 TEST_F(ShapeInferenceTest, RankAndDimInspection) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                      {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(2, c.num_outputs());
@@ -295,7 +295,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
 
 TEST_F(ShapeInferenceTest, NumElements) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                      {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
 
   EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
@@ -309,8 +309,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
 
 TEST_F(ShapeInferenceTest, WithRank) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
-                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
+                     {}, {}, {});
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -348,8 +348,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
 
 TEST_F(ShapeInferenceTest, WithRankAtMost) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
-                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
+                     {}, {}, {});
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -386,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
 
 TEST_F(ShapeInferenceTest, WithRankAtLeast) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
-                     {Unknown(), S({1, -1, 3})}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
+                     {}, {}, {});
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -424,7 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
 
 TEST_F(ShapeInferenceTest, WithValue) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
 
   auto d0 = c.Dim(c.input(0), 0);
   auto d1 = c.Dim(c.input(0), 1);
@@ -467,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
 
 TEST_F(ShapeInferenceTest, MergeDim) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
-                     {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {},
+                     {}, {});
 
   auto d2 = c.Dim(c.input(0), 0);
   auto d_unknown = c.Dim(c.input(0), 1);
@@ -530,7 +530,7 @@ TEST_F(ShapeInferenceTest, MergeDim) {
 
 TEST_F(ShapeInferenceTest, RelaxDim) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2),
                      {S({2, InferenceContext::kUnknownDim, 2, 1,
                          InferenceContext::kUnknownDim})},
                      {}, {}, {});
@@ -578,7 +578,7 @@ TEST_F(ShapeInferenceTest, RelaxDim) {
 TEST_F(ShapeInferenceTest, RelaxShape) {
   NodeDef def;
   InferenceContext c(
-      kVersion, &def, MakeOpDef(7, 2),
+      kVersion, def, MakeOpDef(7, 2),
       {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
        S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
       {}, {}, {});
@@ -647,7 +647,7 @@ TEST_F(ShapeInferenceTest, RelaxShape) {
 
 TEST_F(ShapeInferenceTest, MergeShape) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(7, 2),
                      {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
                       Unknown(), S({1})},
                      {}, {}, {});
@@ -753,7 +753,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
 
 TEST_F(ShapeInferenceTest, MergePrefix) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(4, 2),
                      {
                          Unknown(),
                          S({-1, 2}),
@@ -808,7 +808,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
 
 TEST_F(ShapeInferenceTest, Subshape) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2),
                      {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
 
   ShapeHandle unknown = c.input(1);
@@ -880,7 +880,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
 
 TEST_F(ShapeInferenceTest, Concatenate) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
+  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                      {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
 
   auto in0 = c.input(0);
@@ -907,7 +907,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
 
 TEST_F(ShapeInferenceTest, ReplaceDim) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
+  InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                      {}, {}, {});
 
   auto in = c.input(0);
@@ -939,7 +939,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
 
 TEST_F(ShapeInferenceTest, MakeShape) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                      {}, {});
 
   std::vector<DimensionHandle> dims;
@@ -966,7 +966,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
 TEST_F(ShapeInferenceTest, UnknownShape) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto u0 = c.UnknownShape();
   auto u1 = c.UnknownShape();
@@ -978,7 +978,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
 TEST_F(ShapeInferenceTest, KnownShapeToProto) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto s = c.MakeShape({1, 2, 3});
   TensorShapeProto proto;
@@ -992,7 +992,7 @@ TEST_F(ShapeInferenceTest, KnownShapeToProto) {
 TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto u0 = c.UnknownShape();
   TensorShapeProto proto;
@@ -1005,7 +1005,7 @@ TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
 TEST_F(ShapeInferenceTest, Scalar) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto s0 = c.Scalar();
   EXPECT_EQ("[]", c.DebugString(s0));
@@ -1016,7 +1016,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
 TEST_F(ShapeInferenceTest, Vector) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto s0 = c.Vector(1);
   EXPECT_EQ("[1]", c.DebugString(s0));
@@ -1032,7 +1032,7 @@ TEST_F(ShapeInferenceTest, Vector) {
 TEST_F(ShapeInferenceTest, Matrix) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto s0 = c.Matrix(1, 2);
   EXPECT_EQ("[1,2]", c.DebugString(s0));
@@ -1054,7 +1054,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
   auto create = [&](Tensor* t) {
     NodeDef def;
-    InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
+    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
                        {});
     ShapeHandle out;
     Status s = c.MakeShapeFromShapeTensor(0, &out);
@@ -1115,7 +1115,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
   // Test when the input shape is wrong.
   {
     NodeDef def;
-    InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
+    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
                        {}, {});
     ShapeHandle out;
     EXPECT_EQ("Shape must be rank 1 but is rank 2",
@@ -1126,7 +1126,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
 TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   // With an unknown rank.
   ShapeHandle out;
@@ -1145,7 +1145,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
 TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   ShapeHandle out;
   TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
@@ -1159,7 +1159,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
 TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
   TensorShapeProto proto;
 
   // With a set unknown rank.
@@ -1195,7 +1195,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
 TEST_F(ShapeInferenceTest, MakeDim) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto d0 = c.MakeDim(1);
   auto d1 = c.MakeDim(1);
@@ -1209,7 +1209,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
 TEST_F(ShapeInferenceTest, UnknownDim) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto d0 = c.UnknownDim();
   auto d1 = c.UnknownDim();
@@ -1221,7 +1221,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
 TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
   EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@@ -1234,7 +1234,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
   const Tensor t1 = tensorflow::test::AsTensor<float>({10});
   const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
+  InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                      {&t1, &t2}, {}, {});
 
   EXPECT_TRUE(c.input_tensor(0) == &t1);
@@ -1246,8 +1246,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
   Tensor t1 = tensorflow::test::AsScalar<int32>(20);
   Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
-                     {&t1, &t2}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2},
+                     {}, {});
 
   DimensionHandle d;
   EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@@ -1280,7 +1280,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
             .ok());
 
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
+  InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
   string value;
   EXPECT_TRUE(c.GetAttr("foo", &value).ok());
   EXPECT_EQ("bar", value);
@@ -1288,7 +1288,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
 
 TEST_F(ShapeInferenceTest, Divide) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
                      {}, {});
 
   auto s = c.input(0);
@@ -1351,7 +1351,7 @@ TEST_F(ShapeInferenceTest, Divide) {
 
 TEST_F(ShapeInferenceTest, Add) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                      {});
 
   auto s = c.input(0);
@@ -1401,8 +1401,8 @@ TEST_F(ShapeInferenceTest, Add) {
 
 TEST_F(ShapeInferenceTest, Subtract) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
-                     {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {},
+                     {});
 
   auto s = c.input(0);
   auto d_6 = c.Dim(s, 0);
@@ -1451,8 +1451,8 @@ TEST_F(ShapeInferenceTest, Subtract) {
 
 TEST_F(ShapeInferenceTest, Multiply) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
-                     {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
+                     {});
 
   auto s = c.input(0);
   auto d_6 = c.Dim(s, 0);
@@ -1505,7 +1505,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
 TEST_F(ShapeInferenceTest, FullyDefined) {
   NodeDef def;
   std::vector<ShapeHandle> empty;
-  InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
 
   // No rank or missing dimension information should return false.
   EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@@ -1518,8 +1518,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
 
 TEST_F(ShapeInferenceTest, Min) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
-                     {}, {});
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {},
+                     {});
 
   auto s = c.input(0);
   auto d_1 = c.Dim(s, 0);
@@ -1567,7 +1567,7 @@ TEST_F(ShapeInferenceTest, Min) {
 
 TEST_F(ShapeInferenceTest, Max) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
                      {});
 
   auto s = c.input(0);
@@ -1605,7 +1605,7 @@ TEST_F(ShapeInferenceTest, Max) {
 
 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
                      {});
   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
     ShapeHandle s;
@@ -1716,7 +1716,7 @@ TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
 
 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
   NodeDef def;
-  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
+  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
                      {});
   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
     ShapeHandle s;
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index 2cf447471e3..3145a6e5954 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -60,7 +60,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
     }
   }
   shape_inference::InferenceContext c(
-      op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes,
+      op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes,
       op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
   TF_RETURN_IF_ERROR(c.construction_status());
   if (op_reg_data->shape_inference_fn == nullptr) {
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 8e6e72993f8..925ea44d454 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -1168,7 +1168,7 @@ class SymbolicShapeRefiner {
     std::vector<ShapeHandle> input_tensors_as_shapes;
 
     node_ctx.inference_context.reset(new InferenceContext(
-        graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
+        graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
         input_tensors, input_tensors_as_shapes,
         std::move(input_handle_shapes_and_types)));
     const Status s = node_ctx.inference_context->construction_status();
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 4cda29ee2a9..718a34c07e6 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -198,7 +198,7 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
       new std::vector<std::pair<PartialTensorShape, DataType>>(
           {{PartialTensorShape(), DT_BOOL}}));
   shape_inference::InferenceContext c(
-      TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
+      TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
       {PartialTensorShape()}, {}, {}, handle_data);
   TF_ASSERT_OK(c.construction_status());
   ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 33a010b9349..7ebd7889a35 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) {
   auto run_inference_for_handles = [&]() -> Status {
     CHECK(op_reg_data->shape_inference_fn != nullptr);
     c.reset(new shape_inference::InferenceContext(
-        TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
+        TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
         {PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {},
         {}, handle_data));
     TF_CHECK_OK(c->construction_status());