From 7946b8481f73cf6d91375400237b7bcce545e7be Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 6 Sep 2016 10:47:06 -0800
Subject: [PATCH] Delegate to C++ shape functions for python ops in
 training_ops and io_ops. Change: 132345086

---
 tensorflow/core/ops/training_ops.cc        |   9 +-
 tensorflow/core/ops/training_ops_test.cc   |  65 +++--
 tensorflow/python/ops/io_ops.py            | 149 ++----------
 tensorflow/python/training/training_ops.py | 264 ++-------------------
 4 files changed, 74 insertions(+), 413 deletions(-)

diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index ab82617a136..2c6a9cde91b 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -38,10 +38,11 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
   DimensionHandle unused;
   TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
 
-  // Trailing part of grad matches *s.
-  ShapeHandle grad_subshape;
-  TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
-  TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
+  // Trailing part of grad matches trailing part of *s.
+  ShapeHandle grad_unknown_first;
+  TF_RETURN_IF_ERROR(
+      c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
+  TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));
 
   return Status::OK();
 }
diff --git a/tensorflow/core/ops/training_ops_test.cc b/tensorflow/core/ops/training_ops_test.cc
index d1023a1e73d..9c3489211c8 100644
--- a/tensorflow/core/ops/training_ops_test.cc
+++ b/tensorflow/core/ops/training_ops_test.cc
@@ -30,15 +30,14 @@ static void TestGradAndIndicesErrorHandling(ShapeInferenceTestOp op,
                            grad_indices_spec, shape_spec_end);
   };
 
-  // mismatch between grad[1] and var[0].
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              shape_spec("[1]", "[?,2];[?]").c_str());
+  // mismatch between grad[1] and var[1].
+  INFER_ERROR("Dimension 1 in both shapes must be equal", op,
+              shape_spec("[?,1]", "[?,2];[?]").c_str());
   // grad[0] and indices[0] must match.
   INFER_ERROR("Dimensions must be equal, but are 1 and 2", op,
               shape_spec("?", "[2,?];[1]").c_str());
   // grad is wrong rank.
-  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
-              shape_spec("[1]", "[2];[?]").c_str());
+  INFER_ERROR("must be equal rank", op, shape_spec("[1]", "[?,2];[?]").c_str());
   // indices is wrong rank.
   INFER_ERROR("Shape must be rank 1 but is rank 2", op,
               shape_spec("[?]", "[?];[1,2]").c_str());
@@ -74,7 +73,7 @@ TEST(TrainingOpsTest, SparseApplyProximalGradientDescent_ShapeFn) {
   ShapeInferenceTestOp op("SparseApplyProximalGradientDescent");
 
   // Output is a merge of inputs 0 (var) and the non-indices part of 4 (delta).
-  INFER_OK(op, "[1,?];[];[];[];[?,?,2];[3]", "[d0_0,d4_2]");
+  INFER_OK(op, "[1,?];[];[];[];[?,2];[3]", "[d0_0,d4_1]");
 
   TestGradAndIndicesErrorHandling(op, "[];[];[]");
 
@@ -109,14 +108,14 @@ TEST(TrainingOpsTest, SparseApplyAdadelta_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, 2, and non-indices part of 6 (var, accum,
   // accum_update, grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,?,4];?",
-           "[d0_0,d1_1,d2_2,d6_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,4];?",
+           "[d0_0,d1_1,d2_2,d6_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[1];[];[];[];[1];?");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[1];[2];[];[];[];[1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[];[];[];[?,2];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[];[];[];[?,2];?");
 
   TestGradAndIndicesErrorHandling(op, "?;?;?;?;?");
 
@@ -145,11 +144,11 @@ TEST(TrainingOpsTest, SparseApplyAdagrad_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
   // grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?", "[d0_0,d1_1,d3_3]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[];[1];?");
-  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
-              "[1];[1];[];[2];?");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?", "[d0_0,d1_1,d3_2]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,2];[];[?,1];?");
+  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
+              "[?,1];[?,1];[];[?,?,2];?");
 
   TestGradAndIndicesErrorHandling(op, "?;?");
 
@@ -178,11 +177,11 @@ TEST(TrainingOpsTest, SparseApplyProximalAdagrad_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, and the non-indices part of 5 (var,
   // accum, grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,?,3];?", "[d0_0,d1_1,d5_3]");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,3];?", "[d0_0,d1_1,d5_2]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[];[];[];[?,1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[];[];[];[?,2];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[];[];[];[?,2];?");
 
   TestGradAndIndicesErrorHandling(op, "?;?;?;?");
 
@@ -217,14 +216,14 @@ TEST(TrainingOpsTest, SparseApplyFtrl_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, 2, and non-indices part of 3 (var, accum,
   // linear, grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,?,4];?;[];[];[];[]",
-           "[d0_0,d1_1,d2_2,d3_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,4];?;[];[];[];[]",
+           "[d0_0,d1_1,d2_2,d3_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[1];[?,1];?;[];[];[];[]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[1];[2];[?,1];?;[];[];[];[]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[?,2];?;[];[];[];[]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[?,2];?;[];[];[];[]");
 
   TestGradAndIndicesErrorHandling(op, "?;?", ";?;?;?;?");
 
@@ -255,11 +254,11 @@ TEST(TrainingOpsTest, SparseApplyMomentum_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
   // grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?;[]", "[d0_0,d1_1,d3_3]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[];[?,1];?;[]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[];[?,2];?;[]");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?;[]", "[d0_0,d1_1,d3_2]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,2];[];[?,1];?;[]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[];[?,2];?;[]");
 
   TestGradAndIndicesErrorHandling(op, "?;?", ";?");
 
@@ -316,14 +315,14 @@ TEST(TrainingOpsTest, SparseApplyRMSProp_ShapeFn) {
 
   // Output is a merge of inputs 0, 1, 2, and the non-indices part of 7 (var,
   // ms, mom, and grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,?,4];?",
-           "[d0_0,d1_1,d2_2,d7_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,4];?",
+           "[d0_0,d1_1,d2_2,d7_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[1];[];[];[];[];[?,1];?");
+              "[1];[2];[1];[];[];[];[];[1];?");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[2];[];[];[];[];[?,1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[];[];[];[];[?,2];?");
+              "[1];[1];[2];[];[];[];[];[1];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[];[];[];[];[?,2];?");
 
   TestGradAndIndicesErrorHandling(op, "?;?;?;?;?;?");
 
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 7990dba3b63..df34de3da41 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -138,7 +138,6 @@ from __future__ import print_function
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import gen_io_ops
 # go/tf-wildcard-import
@@ -205,80 +204,12 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
       preferred_shard, name=name)
 
 
-@ops.RegisterShape("Restore")
-def _RestoreShape(op):
-  """Shape function for Restore op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("RestoreSlice")
-def _RestoreSliceShape(op):
-  """Shape function for RestoreSlice op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  # TODO(mrry): Attempt to parse the shape_and_slice value and use it
-  # to form the shape of the output.
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("Save")
-def _SaveShape(op):
-  """Shape function for Save op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 2
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  return []
-
-
-@ops.RegisterShape("SaveSlices")
-def _SaveSlicesShape(op):
-  """Shape function for SaveSlices op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 3
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  # TODO(mrry): Attempt to parse the shapes_and_slices values and use
-  # them to constrain the shape of the remaining inputs.
-  return []
-
-
-@ops.RegisterShape("ShardedFilename")
-def _ShardedFilenameShape(op):
-  """Shape function for ShardedFilename op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shard_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ShardedFilespec")
-def _ShardedFilespecShape(op):
-  """Shape function for ShardedFilespec op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
+ops.RegisterShape("Restore")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("RestoreSlice")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Save")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SaveSlices")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilename")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilespec")(common_shapes.call_cpp_shape_fn)
 
 
 class ReaderBase(object):
@@ -574,61 +505,13 @@ ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
 ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
 ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)
 
-
-@ops.RegisterShape("ReaderNumRecordsProduced")
-@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
-@ops.RegisterShape("ReaderSerializeState")
-def _ReaderScalarShape(op):
-  """Shape function for ops that transform a reader to a scalar."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderRead")
-def _ReaderReadShape(op):
-  """Shape function for the ReaderBase.Read op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_queue_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar(), tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderReadUpTo")
-def _ReaderReadUpToShape(_):
-  """Shape function for the ReaderBase.ReadUpTo op."""
-  return [tensor_shape.unknown_shape(ndims=1),
-          tensor_shape.unknown_shape(ndims=1)]
-
-
-@ops.RegisterShape("ReaderReset")
-def _ReaderResetShape(op):
-  """Shape function for the ReaderBase.Reset op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReaderRestoreState")
-def _ReaderRestoreStateShape(op):
-  """Shape function for the ReaderBase.Restore op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_state_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReadFile")
-def _ReadFileShape(op):
-  """Shape function for the ReadFile op."""
-  return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
-
-
-@ops.RegisterShape("MatchingFiles")
-def _MatchingFilesShape(op):
-  """Shape function for the MatchingFiles op."""
-  unused_patern_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape(ndims=1)]
+ops.RegisterShape("ReaderNumRecordsProduced")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderNumWorkUnitsCompleted")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderSerializeState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRead")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
index b1401e51759..fbb5bb32c1b 100644
--- a/tensorflow/python/training/training_ops.py
+++ b/tensorflow/python/training/training_ops.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.training import gen_training_ops
@@ -48,246 +49,23 @@ def _AssertInputIsScalar(op, index):
   op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
 
 
-@ops.RegisterShape("ApplyAdadelta")
-def _ApplyAdadeltaShape(op):
-  """Shape function for the ApplyAdadelta op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  accum_update_shape = op.inputs[2].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # epsilon
-  grad_shape = op.inputs[6].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-@ops.RegisterShape("ApplyAdagrad")
-def _ApplyAdagradShape(op):
-  """Shape function for the ApplyAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-@ops.RegisterShape("ApplyProximalAdagrad")
-def _ApplyProximalAdagradShape(op):
-  """Shape function for the ApplyProximalAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  _AssertInputIsScalar(op, 3)  # l1
-  _AssertInputIsScalar(op, 4)  # l2
-  grad_shape = op.inputs[5].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyFtrl")
-def _ApplyFtrlShape(op):
-  """Shape function for the ApplyFtrlOp op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(linear_shape)
-  _AssertInputIsScalar(op, 4)  # lr
-  _AssertInputIsScalar(op, 5)  # l1
-  _AssertInputIsScalar(op, 6)  # l2
-  _AssertInputIsScalar(op, 7)  # lr_power
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyAdagradDA")
-def ApplyAdagradDAShape(op):
-  """Shape function for the ApplyAdagradDAOp op."""
-  var_shape = op.inputs[0].get_shape()
-  g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(gg_accum_shape)
-  _AssertInputIsScalar(op, 4)  # lr
-  _AssertInputIsScalar(op, 5)  # l1
-  _AssertInputIsScalar(op, 6)  # l2
-  _AssertInputIsScalar(op, 7)  # global_step
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyAdam")
-def _ApplyAdamShape(op):
-  """Shape function for the ApplyAdam op."""
-  var_shape = op.inputs[0].get_shape()
-  m_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  v_shape = op.inputs[2].get_shape().merge_with(m_shape)
-  _AssertInputIsScalar(op, 3)  # beta1_power
-  _AssertInputIsScalar(op, 4)  # beta2_power
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # beta1
-  _AssertInputIsScalar(op, 7)  # beta2
-  _AssertInputIsScalar(op, 8)  # epsilon
-  grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyMomentum")
-def _ApplyMomentumShape(op):
-  """Shape function for the ApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
-  _AssertInputIsScalar(op, 4)  # momentum
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyRMSProp")
-def _ApplyRMSPropShape(op):
-  """Shape function for the ApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyGradientDescent")
-def _ApplyGradientDescentShape(op):
-  """Shape function for the ApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("ApplyProximalGradientDescent")
-def _ApplyProximalGradientDescentShape(op):
-  """Shape function for the ApplyProximalGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  delta_shape = op.inputs[4].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalGradientDescent")
-def _SparseApplyProximalGradientDescentShape(op):
-  """Shape function for the SparseApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # lr
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  grad_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(var_shape[1:]))
-  unused_indices_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [var_shape]
-
-
-@ops.RegisterShape("SparseApplyRMSProp")
-def _SparseApplyRMSPropShape(op):
-  """Shape function for the SparseApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
-  unused_indices_shape = op.inputs[8].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [mom_shape]
-
-
-@ops.RegisterShape("SparseApplyAdadelta")
-def _SparseApplyAdadeltaShape(op):
-  """Shape function for the SparseApplyAdadelta op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_grad_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  accum_update_shape = op.inputs[2].get_shape().merge_with(accum_grad_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # decay_rate
-  _AssertInputIsScalar(op, 5)  # epsilon
-  grad_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_update_shape[1:]))
-  unused_indices_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_update_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagrad")
-def _SparseApplyAdagradShape(op):
-  """Shape function for the SparseApplyAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalAdagrad")
-def _SparseApplyProximalAdagradShape(op):
-  """Shape function for the SparseApplyProximalAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  _AssertInputIsScalar(op, 3)  # l1
-  _AssertInputIsScalar(op, 4)  # l2
-  grad_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyFtrl")
-def _SparseApplyFtrlShape(op):
-  """Shape function for the SparseApplyFtrl op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(linear_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # lr_power
-  return [linear_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagradDA")
-def _SparseApplyAdagradDAShape(op):
-  """Shape function for the SparseApplyAdagradDA op."""
-  var_shape = op.inputs[0].get_shape()
-  g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(gg_accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # global_step
-  return [gg_accum_shape]
-
-
-@ops.RegisterShape("SparseApplyMomentum")
-def _SparseApplyMomentumShape(op):
-  """Shape function for the SparseApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # momentum
-  return [accum_shape]
+ops.RegisterShape("ApplyAdadelta")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyFtrl")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdam")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyMomentum")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyRMSProp")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyGradientDescent")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyProximalGradientDescent")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyProximalGradientDescent")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyRMSProp")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdadelta")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyFtrl")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyMomentum")(common_shapes.call_cpp_shape_fn)