Delegate to C++ shape functions for python ops in training_ops and io_ops.

Change: 132345086
A. Unique TensorFlower 2016-09-06 10:47:06 -08:00 committed by TensorFlower Gardener
parent 8cec0eebd1
commit 7946b8481f
4 changed files with 74 additions and 413 deletions

tensorflow/core/ops/training_ops.cc

@@ -38,10 +38,11 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
// Trailing part of grad matches *s.
ShapeHandle grad_subshape;
TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
// Trailing part of grad matches trailing part of *s.
ShapeHandle grad_unknown_first;
TF_RETURN_IF_ERROR(
c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));
return Status::OK();
}
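The substance of the C++ change: the old code merged everything after grad's first dimension against the whole of *s, so a sparse grad needed rank one higher than var (grad = [num_indices] + var_shape). The new code replaces grad's dimension 0 with an unknown before merging, so grad must have the same rank as var, with only the trailing dimensions constrained; dimension 0 stays tied to indices by the Merge above. A minimal plain-Python sketch of the two rules with fully known dimensions (illustrative only, not TensorFlow API; unknown-dim merging omitted):

def old_grad_rule(var, grad, indices):
    # grad had to be [num_indices] + var: grad[1:] merged against all of var.
    return grad[0] == indices[0] and list(grad[1:]) == list(var)

def new_grad_rule(var, grad, indices):
    # grad now has var's rank; dim 0 comes from indices, not from var[0].
    return (grad[0] == indices[0] and len(grad) == len(var)
            and list(grad[1:]) == list(var[1:]))

# With var=[10, 4] and indices=[3]:
assert old_grad_rule([10, 4], [3, 10, 4], [3])   # old layout: rank(var) + 1
assert new_grad_rule([10, 4], [3, 4], [3])       # new layout: rank(var)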

tensorflow/core/ops/training_ops_test.cc

@@ -30,15 +30,14 @@ static void TestGradAndIndicesErrorHandling(ShapeInferenceTestOp op,
grad_indices_spec, shape_spec_end);
};
// mismatch between grad[1] and var[0].
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
shape_spec("[1]", "[?,2];[?]").c_str());
// mismatch between grad[1] and var[1].
INFER_ERROR("Dimension 1 in both shapes must be equal", op,
shape_spec("[?,1]", "[?,2];[?]").c_str());
// grad[0] and indices[0] must match.
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op,
shape_spec("?", "[2,?];[1]").c_str());
// grad is wrong rank.
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
shape_spec("[1]", "[2];[?]").c_str());
INFER_ERROR("must be equal rank", op, shape_spec("[1]", "[?,2];[?]").c_str());
// indices is wrong rank.
INFER_ERROR("Shape must be rank 1 but is rank 2", op,
shape_spec("[?]", "[?];[1,2]").c_str());
@@ -74,7 +73,7 @@ TEST(TrainingOpsTest, SparseApplyProximalGradientDescent_ShapeFn) {
ShapeInferenceTestOp op("SparseApplyProximalGradientDescent");
// Output is a merge of inputs 0 (var) and the non-indices part of 4 (delta).
INFER_OK(op, "[1,?];[];[];[];[?,?,2];[3]", "[d0_0,d4_2]");
INFER_OK(op, "[1,?];[];[];[];[?,2];[3]", "[d0_0,d4_1]");
TestGradAndIndicesErrorHandling(op, "[];[];[]");
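Reading the spec strings: inputs are semicolon-separated, ? is a fully unknown shape, [?,2] is rank 2 with an unknown leading dimension, and dI_J in the expected output names dimension J of input I. Decoding the updated INFER_OK above (input names follow the op's signature):

spec = "[1,?];[];[];[];[?,2];[3]"
print(spec.split(";"))  # ['[1,?]', '[]', '[]', '[]', '[?,2]', '[3]']
# input 0  var     = [1, ?]
# input 1  lr      = []      (scalar, as are l1 and l2 at inputs 2-3)
# input 4  grad    = [?, 2]  (same rank as var after this change)
# input 5  indices = [3]
# expected output  = [d0_0, d4_1] = [1, 2]  (var dim 0, grad dim 1)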
@@ -109,14 +108,14 @@ TEST(TrainingOpsTest, SparseApplyAdadelta_ShapeFn) {
// Output is a merge of inputs 0, 1, 2, and non-indices part of 6 (var, accum,
// accum_update, grad).
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,?,4];?",
"[d0_0,d1_1,d2_2,d6_4]");
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,4];?",
"[d0_0,d1_1,d2_2,d6_3]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[1];[];[];[];[1];?");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[2];[];[];[];[1];?");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[1];[];[];[];[?,2];?");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,1];[?,1];[];[];[];[?,2];?");
TestGradAndIndicesErrorHandling(op, "?;?;?;?;?");
@@ -145,11 +144,11 @@ TEST(TrainingOpsTest, SparseApplyAdagrad_ShapeFn) {
// Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
// grad).
INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?", "[d0_0,d1_1,d3_3]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[];[1];?");
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
"[1];[1];[];[2];?");
INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?", "[d0_0,d1_1,d3_2]");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,2];[];[?,1];?");
INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
"[?,1];[?,1];[];[?,?,2];?");
TestGradAndIndicesErrorHandling(op, "?;?");
@@ -178,11 +177,11 @@ TEST(TrainingOpsTest, SparseApplyProximalAdagrad_ShapeFn) {
// Output is a merge of inputs 0, 1, and the non-indices part of 5 (var,
// accum, grad).
INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,?,3];?", "[d0_0,d1_1,d5_3]");
INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,3];?", "[d0_0,d1_1,d5_2]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[];[];[];[?,1];?");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[];[];[];[?,2];?");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,1];[];[];[];[?,2];?");
TestGradAndIndicesErrorHandling(op, "?;?;?;?");
@@ -217,14 +216,14 @@ TEST(TrainingOpsTest, SparseApplyFtrl_ShapeFn) {
// Output is a merge of inputs 0, 1, 2, and non-indices part of 3 (var, accum,
// linear, grad).
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,?,4];?;[];[];[];[]",
"[d0_0,d1_1,d2_2,d3_4]");
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,4];?;[];[];[];[]",
"[d0_0,d1_1,d2_2,d3_3]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[1];[?,1];?;[];[];[];[]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[2];[?,1];?;[];[];[];[]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[1];[?,2];?;[];[];[];[]");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,1];[?,1];[?,2];?;[];[];[];[]");
TestGradAndIndicesErrorHandling(op, "?;?", ";?;?;?;?");
@@ -255,11 +254,11 @@ TEST(TrainingOpsTest, SparseApplyMomentum_ShapeFn) {
// Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
// grad).
INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?;[]", "[d0_0,d1_1,d3_3]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[];[?,1];?;[]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[];[?,2];?;[]");
INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?;[]", "[d0_0,d1_1,d3_2]");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,2];[];[?,1];?;[]");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,1];[];[?,2];?;[]");
TestGradAndIndicesErrorHandling(op, "?;?", ";?");
@@ -316,14 +315,14 @@ TEST(TrainingOpsTest, SparseApplyRMSProp_ShapeFn) {
// Output is a merge of inputs 0, 1, 2, and the non-indices part of 7 (var,
// ms, mom, and grad).
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,?,4];?",
"[d0_0,d1_1,d2_2,d7_4]");
INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,4];?",
"[d0_0,d1_1,d2_2,d7_3]");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[2];[1];[];[];[];[];[?,1];?");
"[1];[2];[1];[];[];[];[];[1];?");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[2];[];[];[];[];[?,1];?");
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
"[1];[1];[1];[];[];[];[];[?,2];?");
"[1];[1];[2];[];[];[];[];[1];?");
INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
"[?,1];[?,1];[?,1];[];[];[];[];[?,2];?");
TestGradAndIndicesErrorHandling(op, "?;?;?;?;?;?");
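The rewritten error tests in these hunks all probe the same new constraint: grad merges against var at equal rank, so a trailing-dimension mismatch now surfaces as a "Dimension 1" error where the old rule reported "Dimension 0" against grad[1:]. Decoding the final SparseApplyRMSProp case (input order per the op: var, ms, mom, lr, rho, momentum, epsilon, grad, indices):

# "[?,1];[?,1];[?,1];[];[];[];[];[?,2];?"
#   var = ms = mom = [?, 1]
#   grad           = [?, 2]
# merge(var, ms, mom)        -> [?, 1]
# merge([?, 1], grad [?, 2]) -> dim 1 holds 1 vs 2
#   -> "Dimension 1 in both shapes must be equal, but are 1 and 2"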

tensorflow/python/ops/io_ops.py

@@ -138,7 +138,6 @@ from __future__ import print_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
@@ -205,80 +204,12 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
preferred_shard, name=name)
@ops.RegisterShape("Restore")
def _RestoreShape(op):
"""Shape function for Restore op."""
# Validate input shapes.
unused_file_pattern = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_tensor_name = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.unknown_shape()]
@ops.RegisterShape("RestoreSlice")
def _RestoreSliceShape(op):
"""Shape function for RestoreSlice op."""
# Validate input shapes.
unused_file_pattern = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_tensor_name = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
tensor_shape.scalar())
# TODO(mrry): Attempt to parse the shape_and_slice value and use it
# to form the shape of the output.
return [tensor_shape.unknown_shape()]
@ops.RegisterShape("Save")
def _SaveShape(op):
"""Shape function for Save op."""
# Validate input shapes.
unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
data_count = len(op.inputs) - 2
unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.vector(data_count))
return []
@ops.RegisterShape("SaveSlices")
def _SaveSlicesShape(op):
"""Shape function for SaveSlices op."""
# Validate input shapes.
unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
data_count = len(op.inputs) - 3
unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.vector(data_count))
unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
tensor_shape.vector(data_count))
# TODO(mrry): Attempt to parse the shapes_and_slices values and use
# them to constrain the shape of the remaining inputs.
return []
@ops.RegisterShape("ShardedFilename")
def _ShardedFilenameShape(op):
"""Shape function for ShardedFilename op."""
# Validate input shapes.
unused_basename_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_shard_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.scalar()]
@ops.RegisterShape("ShardedFilespec")
def _ShardedFilespecShape(op):
"""Shape function for ShardedFilespec op."""
# Validate input shapes.
unused_basename_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.scalar()]
ops.RegisterShape("Restore")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("RestoreSlice")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Save")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SaveSlices")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ShardedFilename")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ShardedFilespec")(common_shapes.call_cpp_shape_fn)
class ReaderBase(object):
@@ -574,61 +505,13 @@ ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)
@ops.RegisterShape("ReaderNumRecordsProduced")
@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
@ops.RegisterShape("ReaderSerializeState")
def _ReaderScalarShape(op):
"""Shape function for ops that transform a reader to a scalar."""
unused_handle_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.scalar()]
@ops.RegisterShape("ReaderRead")
def _ReaderReadShape(op):
"""Shape function for the ReaderBase.Read op."""
unused_handle_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_queue_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.scalar(), tensor_shape.scalar()]
@ops.RegisterShape("ReaderReadUpTo")
def _ReaderReadUpToShape(_):
"""Shape function for the ReaderBase.ReadUpTo op."""
return [tensor_shape.unknown_shape(ndims=1),
tensor_shape.unknown_shape(ndims=1)]
@ops.RegisterShape("ReaderReset")
def _ReaderResetShape(op):
"""Shape function for the ReaderBase.Reset op."""
unused_handle_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
return []
@ops.RegisterShape("ReaderRestoreState")
def _ReaderRestoreStateShape(op):
"""Shape function for the ReaderBase.Restore op."""
unused_handle_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
unused_state_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.scalar())
return []
@ops.RegisterShape("ReadFile")
def _ReadFileShape(op):
"""Shape function for the ReadFile op."""
return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
@ops.RegisterShape("MatchingFiles")
def _MatchingFilesShape(op):
"""Shape function for the MatchingFiles op."""
unused_patern_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.unknown_shape(ndims=1)]
ops.RegisterShape("ReaderNumRecordsProduced")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderNumWorkUnitsCompleted")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderSerializeState")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderRead")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)

tensorflow/python/training/training_ops.py

@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.training import gen_training_ops
@@ -48,246 +49,23 @@ def _AssertInputIsScalar(op, index):
op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
@ops.RegisterShape("ApplyAdadelta")
def _ApplyAdadeltaShape(op):
"""Shape function for the ApplyAdadelta op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
accum_update_shape = op.inputs[2].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 3) # lr
_AssertInputIsScalar(op, 4) # rho
_AssertInputIsScalar(op, 5) # epsilon
grad_shape = op.inputs[6].get_shape().merge_with(accum_shape)
return [grad_shape]
@ops.RegisterShape("ApplyAdagrad")
def _ApplyAdagradShape(op):
"""Shape function for the ApplyAdagrad op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
return [grad_shape]
@ops.RegisterShape("ApplyProximalAdagrad")
def _ApplyProximalAdagradShape(op):
"""Shape function for the ApplyProximalAdagrad op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
_AssertInputIsScalar(op, 3) # l1
_AssertInputIsScalar(op, 4) # l2
grad_shape = op.inputs[5].get_shape().merge_with(accum_shape)
return [grad_shape]
@ops.RegisterShape("ApplyFtrl")
def _ApplyFtrlShape(op):
"""Shape function for the ApplyFtrlOp op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
grad_shape = op.inputs[3].get_shape().merge_with(linear_shape)
_AssertInputIsScalar(op, 4) # lr
_AssertInputIsScalar(op, 5) # l1
_AssertInputIsScalar(op, 6) # l2
_AssertInputIsScalar(op, 7) # lr_power
return [grad_shape]
@ops.RegisterShape("ApplyAdagradDA")
def ApplyAdagradDAShape(op):
"""Shape function for the ApplyAdagradDAOp op."""
var_shape = op.inputs[0].get_shape()
g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
grad_shape = op.inputs[3].get_shape().merge_with(gg_accum_shape)
_AssertInputIsScalar(op, 4) # lr
_AssertInputIsScalar(op, 5) # l1
_AssertInputIsScalar(op, 6) # l2
_AssertInputIsScalar(op, 7) # global_step
return [grad_shape]
@ops.RegisterShape("ApplyAdam")
def _ApplyAdamShape(op):
"""Shape function for the ApplyAdam op."""
var_shape = op.inputs[0].get_shape()
m_shape = op.inputs[1].get_shape().merge_with(var_shape)
v_shape = op.inputs[2].get_shape().merge_with(m_shape)
_AssertInputIsScalar(op, 3) # beta1_power
_AssertInputIsScalar(op, 4) # beta2_power
_AssertInputIsScalar(op, 5) # lr
_AssertInputIsScalar(op, 6) # beta1
_AssertInputIsScalar(op, 7) # beta2
_AssertInputIsScalar(op, 8) # epsilon
grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
return [grad_shape]
@ops.RegisterShape("ApplyMomentum")
def _ApplyMomentumShape(op):
"""Shape function for the ApplyMomentum op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
_AssertInputIsScalar(op, 4) # momentum
return [grad_shape]
@ops.RegisterShape("ApplyRMSProp")
def _ApplyRMSPropShape(op):
"""Shape function for the ApplyRMSProp op."""
var_shape = op.inputs[0].get_shape()
ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
_AssertInputIsScalar(op, 3) # lr
_AssertInputIsScalar(op, 4) # rho
_AssertInputIsScalar(op, 5) # momentum
_AssertInputIsScalar(op, 6) # epsilon
grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
return [grad_shape]
@ops.RegisterShape("ApplyGradientDescent")
def _ApplyGradientDescentShape(op):
"""Shape function for the ApplyGradientDescent op."""
var_shape = op.inputs[0].get_shape()
_AssertInputIsScalar(op, 1) # alpha
delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
return [delta_shape]
@ops.RegisterShape("ApplyProximalGradientDescent")
def _ApplyProximalGradientDescentShape(op):
"""Shape function for the ApplyProximalGradientDescent op."""
var_shape = op.inputs[0].get_shape()
_AssertInputIsScalar(op, 1) # alpha
_AssertInputIsScalar(op, 2) # l1
_AssertInputIsScalar(op, 3) # l2
delta_shape = op.inputs[4].get_shape().merge_with(var_shape)
return [delta_shape]
@ops.RegisterShape("SparseApplyProximalGradientDescent")
def _SparseApplyProximalGradientDescentShape(op):
"""Shape function for the SparseApplyGradientDescent op."""
var_shape = op.inputs[0].get_shape()
_AssertInputIsScalar(op, 1) # lr
_AssertInputIsScalar(op, 2) # l1
_AssertInputIsScalar(op, 3) # l2
grad_shape = op.inputs[4].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(var_shape[1:]))
unused_indices_shape = op.inputs[5].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
return [var_shape]
@ops.RegisterShape("SparseApplyRMSProp")
def _SparseApplyRMSPropShape(op):
"""Shape function for the SparseApplyRMSProp op."""
var_shape = op.inputs[0].get_shape()
ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
_AssertInputIsScalar(op, 3) # lr
_AssertInputIsScalar(op, 4) # rho
_AssertInputIsScalar(op, 5) # momentum
_AssertInputIsScalar(op, 6) # epsilon
grad_shape = op.inputs[7].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
unused_indices_shape = op.inputs[8].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
return [mom_shape]
@ops.RegisterShape("SparseApplyAdadelta")
def _SparseApplyAdadeltaShape(op):
"""Shape function for the SparseApplyAdadelta op."""
var_shape = op.inputs[0].get_shape()
accum_grad_shape = op.inputs[1].get_shape().merge_with(var_shape)
accum_update_shape = op.inputs[2].get_shape().merge_with(accum_grad_shape)
_AssertInputIsScalar(op, 3) # lr
_AssertInputIsScalar(op, 4) # decay_rate
_AssertInputIsScalar(op, 5) # epsilon
grad_shape = op.inputs[6].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(accum_update_shape[1:]))
unused_indices_shape = op.inputs[7].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
return [accum_update_shape]
@ops.RegisterShape("SparseApplyAdagrad")
def _SparseApplyAdagradShape(op):
"""Shape function for the SparseApplyAdagrad op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
grad_shape = op.inputs[3].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
unused_indices_shape = op.inputs[4].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
return [accum_shape]
@ops.RegisterShape("SparseApplyProximalAdagrad")
def _SparseApplyProximalAdagradShape(op):
"""Shape function for the SparseApplyProximalAdagrad op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
_AssertInputIsScalar(op, 3) # l1
_AssertInputIsScalar(op, 4) # l2
grad_shape = op.inputs[5].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
unused_indices_shape = op.inputs[6].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
return [accum_shape]
@ops.RegisterShape("SparseApplyFtrl")
def _SparseApplyFtrlShape(op):
"""Shape function for the SparseApplyFtrl op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
grad_shape = op.inputs[3].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(linear_shape[1:]))
unused_indices_shape = op.inputs[4].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
_AssertInputIsScalar(op, 5) # lr
_AssertInputIsScalar(op, 6) # l1
_AssertInputIsScalar(op, 7) # l2
_AssertInputIsScalar(op, 8) # lr_power
return [linear_shape]
@ops.RegisterShape("SparseApplyAdagradDA")
def _SparseApplyAdagradDAShape(op):
"""Shape function for the SparseApplyAdagradDA op."""
var_shape = op.inputs[0].get_shape()
g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
grad_shape = op.inputs[3].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(gg_accum_shape[1:]))
unused_indices_shape = op.inputs[4].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
_AssertInputIsScalar(op, 5) # lr
_AssertInputIsScalar(op, 6) # l1
_AssertInputIsScalar(op, 7) # l2
_AssertInputIsScalar(op, 8) # global_step
return [gg_accum_shape]
@ops.RegisterShape("SparseApplyMomentum")
def _SparseApplyMomentumShape(op):
"""Shape function for the SparseApplyMomentum op."""
var_shape = op.inputs[0].get_shape()
accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
_AssertInputIsScalar(op, 2) # lr
grad_shape = op.inputs[3].get_shape().merge_with(
tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
unused_indices_shape = op.inputs[4].get_shape().merge_with(
tensor_shape.vector(grad_shape[0]))
_AssertInputIsScalar(op, 5) # momentum
return [accum_shape]
ops.RegisterShape("ApplyAdadelta")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyAdagrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyFtrl")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyAdam")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyMomentum")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyRMSProp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyGradientDescent")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ApplyProximalGradientDescent")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyProximalGradientDescent")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyRMSProp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyAdadelta")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyAdagrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyFtrl")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("SparseApplyMomentum")(common_shapes.call_cpp_shape_fn)