diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index ab82617a136..2c6a9cde91b 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -38,10 +38,11 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
   DimensionHandle unused;
   TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
 
-  // Trailing part of grad matches *s.
-  ShapeHandle grad_subshape;
-  TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
-  TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
+  // Trailing part of grad matches trailing part of *s.
+  ShapeHandle grad_unknown_first;
+  TF_RETURN_IF_ERROR(
+      c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
+  TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));
 
   return Status::OK();
 }
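Note on the change above: the old code took Subshape(grad, 1) and merged it against all of *s, so var's leading dimension was compared with grad's second dimension, and grad was effectively required to have one more dimension than var. The new code blanks out grad's leading (indices) dimension and merges the full shapes, so grad must have the same rank as var and agree with it on every trailing dimension, while var's leading dimension stays unconstrained. A minimal stand-alone sketch of the two strategies (plain Python, not TensorFlow code; None plays the role of an unknown dimension):

def merge(a, b):
  """Toy analogue of InferenceContext::Merge over dimension lists."""
  if len(a) != len(b):
    raise ValueError("Shapes must be equal rank, but are %d and %d"
                     % (len(a), len(b)))
  for i, (x, y) in enumerate(zip(a, b)):
    if x is not None and y is not None and x != y:
      raise ValueError("Dimension %d in both shapes must be equal, "
                       "but are %d and %d" % (i, x, y))
  return [y if x is None else x for x, y in zip(a, b)]

var, grad = [1], [None, 2]   # "[1]" and "[?,2]" in the tests' notation

# Old: Subshape(grad, 1), then Merge -- var[0] is matched against grad[1].
# New: ReplaceDim(grad, 0, UnknownDim()), then Merge -- grad must have
# var's rank, so the same inputs now fail the rank check instead.
for label, args in [("old", (var, grad[1:])),
                    ("new", (var, [None] + grad[1:]))]:
  try:
    print(label, "->", merge(*args))
  except ValueError as e:
    print(label, "->", e)
# old -> Dimension 0 in both shapes must be equal, but are 1 and 2
# new -> Shapes must be equal rank, but are 1 and 2

This is exactly the shift visible in the updated test expectations below.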
- INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?", "[d0_0,d1_1,d3_3]"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[2];[];[1];?"); - INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, - "[1];[1];[];[2];?"); + INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?", "[d0_0,d1_1,d3_2]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,2];[];[?,1];?"); + INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, + "[?,1];[?,1];[];[?,?,2];?"); TestGradAndIndicesErrorHandling(op, "?;?"); @@ -178,11 +177,11 @@ TEST(TrainingOpsTest, SparseApplyProximalAdagrad_ShapeFn) { // Output is a merge of inputs 0, 1, and the non-indices part of 5 (var, // accum, grad). - INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,?,3];?", "[d0_0,d1_1,d5_3]"); + INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,3];?", "[d0_0,d1_1,d5_2]"); INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, "[1];[2];[];[];[];[?,1];?"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[1];[];[];[];[?,2];?"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,1];[];[];[];[?,2];?"); TestGradAndIndicesErrorHandling(op, "?;?;?;?"); @@ -217,14 +216,14 @@ TEST(TrainingOpsTest, SparseApplyFtrl_ShapeFn) { // Output is a merge of inputs 0, 1, 2, and non-indices part of 3 (var, accum, // linear, grad). - INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,?,4];?;[];[];[];[]", - "[d0_0,d1_1,d2_2,d3_4]"); + INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,4];?;[];[];[];[]", + "[d0_0,d1_1,d2_2,d3_3]"); INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, "[1];[2];[1];[?,1];?;[];[];[];[]"); INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, "[1];[1];[2];[?,1];?;[];[];[];[]"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[1];[1];[?,2];?;[];[];[];[]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,1];[?,1];[?,2];?;[];[];[];[]"); TestGradAndIndicesErrorHandling(op, "?;?", ";?;?;?;?"); @@ -255,11 +254,11 @@ TEST(TrainingOpsTest, SparseApplyMomentum_ShapeFn) { // Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum, // grad). - INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?;[]", "[d0_0,d1_1,d3_3]"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[2];[];[?,1];?;[]"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[1];[];[?,2];?;[]"); + INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?;[]", "[d0_0,d1_1,d3_2]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,2];[];[?,1];?;[]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,1];[];[?,2];?;[]"); TestGradAndIndicesErrorHandling(op, "?;?", ";?"); @@ -316,14 +315,14 @@ TEST(TrainingOpsTest, SparseApplyRMSProp_ShapeFn) { // Output is a merge of inputs 0, 1, 2, and the non-indices part of 7 (var, // ms, mom, and grad). 
- INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,?,4];?", - "[d0_0,d1_1,d2_2,d7_4]"); + INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,4];?", + "[d0_0,d1_1,d2_2,d7_3]"); INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[2];[1];[];[];[];[];[?,1];?"); + "[1];[2];[1];[];[];[];[];[1];?"); INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[1];[2];[];[];[];[];[?,1];?"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, - "[1];[1];[1];[];[];[];[];[?,2];?"); + "[1];[1];[2];[];[];[];[];[1];?"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op, + "[?,1];[?,1];[?,1];[];[];[];[];[?,2];?"); TestGradAndIndicesErrorHandling(op, "?;?;?;?;?;?"); diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 7990dba3b63..df34de3da41 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -138,7 +138,6 @@ from __future__ import print_function from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import gen_io_ops # go/tf-wildcard-import @@ -205,80 +204,12 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, preferred_shard, name=name) -@ops.RegisterShape("Restore") -def _RestoreShape(op): - """Shape function for Restore op.""" - # Validate input shapes. - unused_file_pattern = op.inputs[0].get_shape().merge_with( - tensor_shape.scalar()) - unused_tensor_name = op.inputs[1].get_shape().merge_with( - tensor_shape.scalar()) - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("RestoreSlice") -def _RestoreSliceShape(op): - """Shape function for RestoreSlice op.""" - # Validate input shapes. - unused_file_pattern = op.inputs[0].get_shape().merge_with( - tensor_shape.scalar()) - unused_tensor_name = op.inputs[1].get_shape().merge_with( - tensor_shape.scalar()) - unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with( - tensor_shape.scalar()) - # TODO(mrry): Attempt to parse the shape_and_slice value and use it - # to form the shape of the output. - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("Save") -def _SaveShape(op): - """Shape function for Save op.""" - # Validate input shapes. - unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - data_count = len(op.inputs) - 2 - unused_tensor_names_shape = op.inputs[1].get_shape().merge_with( - tensor_shape.vector(data_count)) - return [] - - -@ops.RegisterShape("SaveSlices") -def _SaveSlicesShape(op): - """Shape function for SaveSlices op.""" - # Validate input shapes. - unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - data_count = len(op.inputs) - 3 - unused_tensor_names_shape = op.inputs[1].get_shape().merge_with( - tensor_shape.vector(data_count)) - unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with( - tensor_shape.vector(data_count)) - # TODO(mrry): Attempt to parse the shapes_and_slices values and use - # them to constrain the shape of the remaining inputs. - return [] - - -@ops.RegisterShape("ShardedFilename") -def _ShardedFilenameShape(op): - """Shape function for ShardedFilename op.""" - # Validate input shapes. 
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 7990dba3b63..df34de3da41 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -138,7 +138,6 @@ from __future__ import print_function
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import gen_io_ops
 # go/tf-wildcard-import
@@ -205,80 +204,12 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                               preferred_shard, name=name)
 
 
-@ops.RegisterShape("Restore")
-def _RestoreShape(op):
-  """Shape function for Restore op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("RestoreSlice")
-def _RestoreSliceShape(op):
-  """Shape function for RestoreSlice op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  # TODO(mrry): Attempt to parse the shape_and_slice value and use it
-  # to form the shape of the output.
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("Save")
-def _SaveShape(op):
-  """Shape function for Save op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 2
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  return []
-
-
-@ops.RegisterShape("SaveSlices")
-def _SaveSlicesShape(op):
-  """Shape function for SaveSlices op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 3
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  # TODO(mrry): Attempt to parse the shapes_and_slices values and use
-  # them to constrain the shape of the remaining inputs.
-  return []
-
-
-@ops.RegisterShape("ShardedFilename")
-def _ShardedFilenameShape(op):
-  """Shape function for ShardedFilename op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shard_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ShardedFilespec")
-def _ShardedFilespecShape(op):
-  """Shape function for ShardedFilespec op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
+ops.RegisterShape("Restore")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("RestoreSlice")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Save")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SaveSlices")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilename")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilespec")(common_shapes.call_cpp_shape_fn)
 
 
 class ReaderBase(object):
@@ -574,61 +505,13 @@ ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
 ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
 ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)
 
-
-@ops.RegisterShape("ReaderNumRecordsProduced")
-@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
-@ops.RegisterShape("ReaderSerializeState")
-def _ReaderScalarShape(op):
-  """Shape function for ops that transform a reader to a scalar."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderRead")
-def _ReaderReadShape(op):
-  """Shape function for the ReaderBase.Read op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_queue_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar(), tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderReadUpTo")
-def _ReaderReadUpToShape(_):
-  """Shape function for the ReaderBase.ReadUpTo op."""
-  return [tensor_shape.unknown_shape(ndims=1),
-          tensor_shape.unknown_shape(ndims=1)]
-
-
-@ops.RegisterShape("ReaderReset")
-def _ReaderResetShape(op):
-  """Shape function for the ReaderBase.Reset op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReaderRestoreState")
-def _ReaderRestoreStateShape(op):
-  """Shape function for the ReaderBase.Restore op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_state_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReadFile")
-def _ReadFileShape(op):
-  """Shape function for the ReadFile op."""
-  return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
-
-
-@ops.RegisterShape("MatchingFiles")
-def _MatchingFilesShape(op):
-  """Shape function for the MatchingFiles op."""
-  unused_patern_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape(ndims=1)]
+ops.RegisterShape("ReaderNumRecordsProduced")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderNumWorkUnitsCompleted")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderSerializeState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRead")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
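The hand-written Python shape functions above are replaced by delegation to the shape functions already registered in C++: common_shapes.call_cpp_shape_fn runs the C++ shape inference for the node and converts the result back into TensorShapes, so the Python and C++ paths can no longer disagree. The direct-call spelling works because ops.RegisterShape("X") returns a decorator. A small sketch of the two equivalent registration styles under the 0.x-era API ("MyOp" is a hypothetical op name, not one from this change):

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops

# Decorator form, as the deleted hand-written shape functions used:
#
#   @ops.RegisterShape("MyOp")
#   def _my_op_shape(op):
#     return [op.inputs[0].get_shape()]
#
# Direct-call form used throughout this change: apply the same decorator
# object to the generic C++ delegate instead of a custom Python function.
ops.RegisterShape("MyOp")(common_shapes.call_cpp_shape_fn)

Stacked registrations, such as the three decorators on _ReaderScalarShape, simply become one registration line per op.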
+ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index b1401e51759..fbb5bb32c1b 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.training import gen_training_ops @@ -48,246 +49,23 @@ def _AssertInputIsScalar(op, index): op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar()) -@ops.RegisterShape("ApplyAdadelta") -def _ApplyAdadeltaShape(op): - """Shape function for the ApplyAdadelta op.""" - var_shape = op.inputs[0].get_shape() - accum_shape = op.inputs[1].get_shape().merge_with(var_shape) - accum_update_shape = op.inputs[2].get_shape().merge_with(var_shape) - _AssertInputIsScalar(op, 3) # lr - _AssertInputIsScalar(op, 4) # rho - _AssertInputIsScalar(op, 5) # epsilon - grad_shape = op.inputs[6].get_shape().merge_with(accum_shape) - return [grad_shape] - -@ops.RegisterShape("ApplyAdagrad") -def _ApplyAdagradShape(op): - """Shape function for the ApplyAdagrad op.""" - var_shape = op.inputs[0].get_shape() - accum_shape = op.inputs[1].get_shape().merge_with(var_shape) - _AssertInputIsScalar(op, 2) # lr - grad_shape = op.inputs[3].get_shape().merge_with(accum_shape) - return [grad_shape] - -@ops.RegisterShape("ApplyProximalAdagrad") -def _ApplyProximalAdagradShape(op): - """Shape function for the ApplyProximalAdagrad op.""" - var_shape = op.inputs[0].get_shape() - accum_shape = op.inputs[1].get_shape().merge_with(var_shape) - _AssertInputIsScalar(op, 2) # lr - _AssertInputIsScalar(op, 3) # l1 - _AssertInputIsScalar(op, 4) # l2 - grad_shape = op.inputs[5].get_shape().merge_with(accum_shape) - return [grad_shape] - - -@ops.RegisterShape("ApplyFtrl") -def _ApplyFtrlShape(op): - """Shape function for the ApplyFtrlOp op.""" - var_shape = op.inputs[0].get_shape() - accum_shape = op.inputs[1].get_shape().merge_with(var_shape) - linear_shape = op.inputs[2].get_shape().merge_with(accum_shape) - grad_shape = op.inputs[3].get_shape().merge_with(linear_shape) - _AssertInputIsScalar(op, 4) # lr - _AssertInputIsScalar(op, 5) # l1 - _AssertInputIsScalar(op, 6) # l2 - _AssertInputIsScalar(op, 7) # lr_power - return [grad_shape] - - -@ops.RegisterShape("ApplyAdagradDA") -def ApplyAdagradDAShape(op): - """Shape function for the ApplyAdagradDAOp op.""" - var_shape = op.inputs[0].get_shape() - g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape) - gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape) - grad_shape = op.inputs[3].get_shape().merge_with(gg_accum_shape) - _AssertInputIsScalar(op, 4) # lr - _AssertInputIsScalar(op, 5) # l1 - _AssertInputIsScalar(op, 6) # l2 - _AssertInputIsScalar(op, 7) # global_step - return [grad_shape] - - -@ops.RegisterShape("ApplyAdam") -def _ApplyAdamShape(op): - """Shape function for the ApplyAdam op.""" - var_shape = op.inputs[0].get_shape() - m_shape = op.inputs[1].get_shape().merge_with(var_shape) - v_shape = op.inputs[2].get_shape().merge_with(m_shape) - 
-  _AssertInputIsScalar(op, 3)  # beta1_power
-  _AssertInputIsScalar(op, 4)  # beta2_power
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # beta1
-  _AssertInputIsScalar(op, 7)  # beta2
-  _AssertInputIsScalar(op, 8)  # epsilon
-  grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyMomentum")
-def _ApplyMomentumShape(op):
-  """Shape function for the ApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
-  _AssertInputIsScalar(op, 4)  # momentum
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyRMSProp")
-def _ApplyRMSPropShape(op):
-  """Shape function for the ApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyGradientDescent")
-def _ApplyGradientDescentShape(op):
-  """Shape function for the ApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("ApplyProximalGradientDescent")
-def _ApplyProximalGradientDescentShape(op):
-  """Shape function for the ApplyProximalGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  delta_shape = op.inputs[4].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalGradientDescent")
-def _SparseApplyProximalGradientDescentShape(op):
-  """Shape function for the SparseApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # lr
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  grad_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(var_shape[1:]))
-  unused_indices_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [var_shape]
-
-
-@ops.RegisterShape("SparseApplyRMSProp")
-def _SparseApplyRMSPropShape(op):
-  """Shape function for the SparseApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
-  unused_indices_shape = op.inputs[8].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [mom_shape]
-
-
-@ops.RegisterShape("SparseApplyAdadelta")
-def _SparseApplyAdadeltaShape(op):
-  """Shape function for the SparseApplyAdadelta op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_grad_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  accum_update_shape = op.inputs[2].get_shape().merge_with(accum_grad_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # decay_rate
-  _AssertInputIsScalar(op, 5)  # epsilon
-  grad_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_update_shape[1:]))
-  unused_indices_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_update_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagrad")
-def _SparseApplyAdagradShape(op):
-  """Shape function for the SparseApplyAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalAdagrad")
-def _SparseApplyProximalAdagradShape(op):
-  """Shape function for the SparseApplyProximalAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  _AssertInputIsScalar(op, 3)  # l1
-  _AssertInputIsScalar(op, 4)  # l2
-  grad_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyFtrl")
-def _SparseApplyFtrlShape(op):
-  """Shape function for the SparseApplyFtrl op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(linear_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # lr_power
-  return [linear_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagradDA")
-def _SparseApplyAdagradDAShape(op):
-  """Shape function for the SparseApplyAdagradDA op."""
-  var_shape = op.inputs[0].get_shape()
-  g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(gg_accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # global_step
-  return [gg_accum_shape]
-
-
-@ops.RegisterShape("SparseApplyMomentum")
-def _SparseApplyMomentumShape(op):
-  """Shape function for the SparseApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # momentum
-  return [accum_shape]
+ops.RegisterShape("ApplyAdadelta")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyFtrl")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyAdagradDA")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyAdam")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyMomentum")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyRMSProp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyGradientDescent")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ApplyProximalGradientDescent")( + common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyProximalGradientDescent")( + common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyRMSProp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyAdadelta")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyAdagrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyFtrl")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyAdagradDA")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseApplyMomentum")(common_shapes.call_cpp_shape_fn)