diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index c67007dca0a..ba5a9268b4f 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                 TF_Status* status) {
   mutex_lock l(graph->mu);
+  tensorflow::shape_inference::InferenceContext* ic =
+      graph->refiner.GetContext(&new_src.oper->node);
+
+  if (ic->num_outputs() <= new_src.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Output index [", new_src.index,
+        "] is greater than the number of total outputs [", ic->num_outputs(),
+        "].");
+    return;
+  }
+  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
+
+  tensorflow::shape_inference::InferenceContext* ic_dst =
+      graph->refiner.GetContext(&dst.oper->node);
+  if (ic_dst->num_inputs() <= dst.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Input index [", dst.index,
+        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
+        "].");
+    return;
+  }
+  if (!ic_dst->MergeInput(dst.index, shape)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
+        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
+    return;
+  }
   status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                            &dst.oper->node, dst.index);
 }
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 1ed5eb3f228..8e314c7ea57 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -333,7 +333,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
     InferenceContext* c = iter->second->get_context();
     DCHECK_GE(dst_input, 0);
     ShapeHandle existing_input = node_context->input(dst_input);
-    if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) {
+    if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) &&
+        !existing_input.SameHandle(node_context->input(dst_input))) {
       *refined = true;
     } else if (relax) {
       if (node_context->RelaxInput(dst_input, c->output(src_output))) {
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index 676fc7ccedf..ff32e855d59 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) {
   EXPECT_FALSE(refined);
   ctx = m.GetContext(dequeue);
   EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
-  ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
+  EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
+
+  // Inject a shape of the same handle and expect refined to not change.
+  ctx = m.GetContext(queue);
+  shape_inference::ShapeHandle shp2 = shp;
+  ctx->set_output_handle_shapes_and_types(
+      0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
+  refined = false;
+  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
+  EXPECT_FALSE(refined);
+  EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
 }
 
 void TestSimpleFunctionInference(bool enable_function_inference,
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 5d6bf559bb3..fe0742e1db5 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
       return_s1 = false;
     } else if (v0 != v1) {
       *out = nullptr;
-      return errors::InvalidArgument("Dimension ", i,
-                                     " in both shapes must be equal, but are ",
-                                     Value(d0), " and ", Value(d1));
+      return errors::InvalidArgument(
+          "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
+          " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
+          DebugString(s1), ".");
     }
   }
 
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 485980e42ee..b12d37b4c03 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -237,24 +237,19 @@ class InferenceContext {
   // - For any one dimension, if the values for that dimension in both shapes
   //   are known, then the values must match.
   // - If one shape has equal or more information than the other shape in every
-  //   dimension, the shape with more information will be returned. Otherwise a
-  //   new shape holding the combined information of the input shapes will be
-  //   returned.
+  //   dimension, the new shape will become the shape with more information.
   // - Example: merging [2,?] and [?,2] results in [2,2]
   // - Example: [2,2] cannot be merged with [1,2]
   //
   // This requires idx to be in the [0, num_inputs) range. If the merge is
-  // successful and the new shape differs from the old one, store the new shape
-  // and return true. Return false otherwise.
+  // successful, return true. Return false otherwise.
   bool MergeInput(int idx, ShapeHandle shape) {
     ShapeHandle new_shape;
-    if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
-        inputs_[idx].SameHandle(new_shape)) {
-      return false;
-    }
+    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
     inputs_[idx] = new_shape;
     return true;
   }
+
   // Relax the stored shape of the input in position idx with <shape> according
   // to the following rules:
   //
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index e5d57facaa7..7686cef2195 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -511,6 +511,13 @@ TEST_F(GraphTest, UpdateEdge) {
   EXPECT_EQ(
       s.error_message(),
       "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
+
+  // Update a's 1st input which is out of range.
+  s = graph_.UpdateEdge(c, 0, a, 0);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(
+      s.error_message(),
+      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
 }
 
 TEST_F(GraphTest, InputEdges) {
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d2608845ac7..b0abbfc7dcc 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1805,7 +1805,7 @@ class Operation(object):
     tensor._add_consumer(self)  # pylint: disable=protected-access
     self._recompute_node_def()
 
-  def _update_input(self, index, tensor, dtype=None):
+  def _update_input(self, index, tensor):
     """Update the input to this operation at the given index.
 
     NOTE: This is for TF internal use only. Please don't use it.
@@ -1813,8 +1813,6 @@ class Operation(object):
     Args:
       index: the index of the input to update.
       tensor: the Tensor to be used as the input at the given index.
-      dtype: tf.DType: type of the input; defaults to
-        the tensor's dtype.
 
     Raises:
       TypeError: if tensor is not a Tensor,
@@ -1832,17 +1830,9 @@ class Operation(object):
             self._tf_input(index),
             status)
     else:
-      if dtype is None:
-        dtype = tensor.dtype
-      else:
-        dtype = dtypes.as_dtype(dtype)
-        if not dtype.is_compatible_with(tensor.dtype):
-          raise TypeError(
-              "Cannot convert a tensor of type %s to an input of type %s" %
-              (tensor.dtype.name, dtype.name))
       self._inputs[index].consumers().remove(self)
       self._inputs[index] = tensor
-      self._input_types_val[index] = dtype
+      self._input_types_val[index] = tensor.dtype
       tensor._add_consumer(self)  # pylint: disable=protected-access
       self._recompute_node_def()
 
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 4e931e00c59..1be306ddc59 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -492,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase):
       with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
         z.op._update_input(0, x)  # pylint: disable=protected-access
 
-  # TODO(nolivia): check the shape/type in _update_input() instead of depending
-  # on run to do that.
   def testUpdateInputTypeError(self):
     g = ops.Graph()
     with g.as_default():
@@ -509,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase):
           "with expected int32"):
         sess.run(z)
 
+  def testUpdateInputShapeError(self):
+    # C-API throws the error differently.
+    if ops._USE_C_API:
+      return
+    g = ops.Graph()
+    with g.as_default():
+      w = constant_op.constant(2, shape=[3, 1])
+      x = constant_op.constant(0, shape=[3, 1])
+      y = constant_op.constant(1, shape=[2, 2])
+      z = w + x
+      z.op._update_input(0, y)  # pylint: disable=protected-access
+
+    with session.Session(graph=g) as sess:
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
+        sess.run(z)
+
+  def testUpdateInputShapeErrorC(self):
+    if not ops._USE_C_API:
+      return
+    g = ops.Graph()
+    with g.as_default():
+      w = constant_op.constant(2, shape=[3, 1])
+      x = constant_op.constant(0, shape=[3, 1])
+      y = constant_op.constant(1, shape=[2, 2])
+      z = w + x
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
+      z.op._update_input(0, y)  # pylint: disable=protected-access
+
   def testUpdateInputOutOfRange(self):
     # C-API throws the error differently.
     if ops._USE_C_API: return
@@ -524,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase):
     g = ops.Graph()
     with g.as_default():
       x = constant_op.constant(1)
-    with self.assertRaisesRegexp(errors.OutOfRangeError,
-                                 r"Node 'Const' \(type: 'Const', "
-                                 r"num of inputs: 0\) does not have input 1"):
+    with self.assertRaisesRegexp(
+        errors.OutOfRangeError,
+        r"Cannot update edge. Input index \[1\] is greater than the number of "
+        r"total inputs \[0\]."
+    ):
       x.op._update_input(1, x)  # pylint: disable=protected-access
 
   def testOpDef(self):